Skip to content

Commit

Permalink
Visualization isn't useful when there are too many data points. Can w…
Browse files Browse the repository at this point in the history
…e subsample? (#1672)
  • Loading branch information
R-Palazzo authored Nov 13, 2023
1 parent ba77e45 commit 0f4fa36
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 5 deletions.
6 changes: 5 additions & 1 deletion sdv/evaluation/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def get_column_plot(real_data, synthetic_data, metadata, table_name, column_name


def get_column_pair_plot(real_data, synthetic_data, metadata,
table_name, column_names, plot_type=None):
table_name, column_names, plot_type=None, sample_size=None):
"""Get a plot of the real and synthetic data for a given column pair.
Args:
Expand All @@ -107,6 +107,9 @@ def get_column_pair_plot(real_data, synthetic_data, metadata,
If ``None``, select between ``box``, ``heatmap`` or ``scatter`` depending on the data
that the column contains, ``scatter`` used for datetime and numerical values,
``heatmap`` for categorical and ``box`` for a mix of both. Defaults to ``None``.
sample_size (int or None):
The number of samples to plot. If ``None``, all samples are plotted.
Defaults to ``None``.
Returns:
plotly.graph_objects._figure.Figure:
Expand All @@ -120,6 +123,7 @@ def get_column_pair_plot(real_data, synthetic_data, metadata,
synthetic_data,
metadata,
column_names,
sample_size,
plot_type
)

Expand Down
11 changes: 10 additions & 1 deletion sdv/evaluation/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def get_column_plot(real_data, synthetic_data, metadata, column_name, plot_type=
)


def get_column_pair_plot(real_data, synthetic_data, metadata, column_names, plot_type=None):
def get_column_pair_plot(
real_data, synthetic_data, metadata, column_names, plot_type=None, sample_size=None):
"""Get a plot of the real and synthetic data for a given column pair.
Args:
Expand All @@ -124,6 +125,9 @@ def get_column_pair_plot(real_data, synthetic_data, metadata, column_names, plot
If ``None``, select between ``box``, ``heatmap`` or ``scatter`` depending on the data
that the column contains, ``scatter`` used for datetime and numerical values,
``heatmap`` for categorical and ``box`` for a mix of both. Defaults to ``None``.
sample_size (int or None):
The number of samples to use for the plot. If ``None`` use the whole dataset.
Defaults to ``None``.
Returns:
plotly.graph_objects._figure.Figure:
Expand Down Expand Up @@ -164,6 +168,11 @@ def get_column_pair_plot(real_data, synthetic_data, metadata, column_names, plot
format=datetime_format
)

require_subsample = sample_size and sample_size < min(len(real_data), len(synthetic_data))
if require_subsample:
real_data = real_data.sample(n=sample_size)
synthetic_data = synthetic_data.sample(n=sample_size)

return visualization.get_column_pair_plot(
real_data,
synthetic_data,
Expand Down
29 changes: 28 additions & 1 deletion tests/integration/evaluation/test_single_table.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@

import pandas as pd

from sdv.evaluation.single_table import evaluate_quality, run_diagnostic
from sdv.datasets.demo import download_demo
from sdv.evaluation.single_table import evaluate_quality, get_column_pair_plot, run_diagnostic
from sdv.metadata.single_table import SingleTableMetadata
from sdv.single_table.copulas import GaussianCopulaSynthesizer

Expand Down Expand Up @@ -29,3 +30,29 @@ def test_evaluation():
],
'WARNING': []
}


def test_column_pair_plot_sample_size_parameter():
    """Check that ``sample_size`` caps the number of points in the column pair plot.

    The synthetic data passed in must keep its original length, while the first two
    traces of the returned figure must each hold exactly ``sample_size`` x values.
    """
    # Setup
    real_data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests')
    synthesizer = GaussianCopulaSynthesizer(metadata)
    synthesizer.fit(real_data)
    num_rows = len(real_data)
    synthetic_data = synthesizer.sample(num_rows)

    # Run
    plot_kwargs = {
        'real_data': real_data,
        'synthetic_data': synthetic_data,
        'column_names': ['room_rate', 'amenities_fee'],
        'metadata': metadata,
        'sample_size': 40,
    }
    fig = get_column_pair_plot(**plot_kwargs)

    # Assert
    # The caller's synthetic table is expected to still contain all of its rows.
    assert len(synthetic_data) == 500
    first_trace, second_trace = fig.data[0], fig.data[1]
    assert len(first_trace.x) == 40
    assert len(second_trace.x) == 40
4 changes: 2 additions & 2 deletions tests/unit/evaluation/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ def test_get_column_pair_plot(mock_plot):
mock_plot.return_value = 'plot'

# Run
plot = get_column_pair_plot(data1, data2, metadata, 'table', ['col1', 'col2'])
plot = get_column_pair_plot(data1, data2, metadata, 'table', ['col1', 'col2'], 2)

# Assert
call_metadata = metadata.tables['table']
mock_plot.assert_called_once_with(table1, table2, call_metadata, ['col1', 'col2'], None)
mock_plot.assert_called_once_with(table1, table2, call_metadata, ['col1', 'col2'], None, 2)
assert plot == 'plot'


Expand Down
57 changes: 57 additions & 0 deletions tests/unit/evaluation/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,60 @@ def test_get_column_pair_plot_with_invalid_sdtype_and_plot_type(mock_get_plot):
assert mock_get_plot.call_args[0][2] == columns
assert mock_get_plot.call_args[0][3] == 'heatmap'
assert plot == mock_get_plot.return_value


@patch('sdmetrics.visualization.get_column_pair_plot')
def test_get_column_pair_plot_with_sample_size(mock_get_plot):
    """Test that ``sample_size`` makes ``get_column_pair_plot`` forward subsampled data.

    With three-row inputs and ``sample_size=2``, the mocked plotting function must receive
    two-row real and synthetic tables whose values all come from the original tables.
    """
    # Setup
    columns = ['amount', 'price']
    real_data = pd.DataFrame({'amount': [1, 2, 3], 'price': [10, 20, 30]})
    synthetic_data = pd.DataFrame({'amount': [1., 2., 3.], 'price': [11., 22., 33.]})
    metadata = SingleTableMetadata()
    for column in columns:
        metadata.add_column(column, sdtype='numerical')

    # Run
    get_column_pair_plot(real_data, synthetic_data, metadata, columns, sample_size=2)

    # Assert
    forwarded_args = mock_get_plot.call_args[0]
    real_subsample = forwarded_args[0]
    synthetic_subsample = forwarded_args[1]
    assert len(real_subsample) == 2
    assert len(synthetic_subsample) == 2
    assert real_subsample.isin(real_data).all().all()
    assert synthetic_subsample.isin(synthetic_data).all().all()


@patch('sdmetrics.visualization.get_column_pair_plot')
def test_get_column_pair_plot_with_sample_size_too_big(mock_get_plot):
    """Test ``get_column_pair_plot`` when ``sample_size`` exceeds the number of rows.

    With three-row inputs and ``sample_size=10``, the mocked plotting function must receive
    tables equal to the full real and synthetic data, along with the column names and the
    ``'scatter'`` plot type.
    """
    # Setup
    columns = ['amount', 'price']
    real_data = pd.DataFrame({'amount': [1, 2, 3], 'price': [10, 20, 30]})
    synthetic_data = pd.DataFrame({'amount': [1., 2., 3.], 'price': [11., 22., 33.]})
    metadata = SingleTableMetadata()
    for column in columns:
        metadata.add_column(column, sdtype='numerical')

    # Run
    plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns, sample_size=10)

    # Assert
    forwarded_args = mock_get_plot.call_args[0]
    pd.testing.assert_frame_equal(forwarded_args[0], real_data)
    pd.testing.assert_frame_equal(forwarded_args[1], synthetic_data)
    assert forwarded_args[2] == columns
    assert forwarded_args[3] == 'scatter'
    assert plot == mock_get_plot.return_value

0 comments on commit 0f4fa36

Please sign in to comment.