Skip to content

Commit

Permalink
Tweak the threshold logic for the correlation matrix
Browse files: browse the repository at this point in the history
  • Loading branch information
brifordwylie committed Dec 10, 2024
1 parent ddcfda1 commit e8074d5
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 12 deletions.
2 changes: 1 addition & 1 deletion examples/plugins/pages/plugin_page_2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Plugin Page 1: A 'Hello World' SageWorks Plugin Page"""
"""Plugin Page 2: A 'Hello World' SageWorks Plugin Page"""

import dash
from dash import html, page_container, register_page
Expand Down
2 changes: 1 addition & 1 deletion examples/plugins/pages/plugin_page_3.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Plugin Page 1: A 'Hello World' SageWorks Plugin Page"""
"""Plugin Page 3: A 'Hello World' SageWorks Plugin Page"""

import dash
from dash import html, page_container, register_page, callback, Output, Input, State, no_update
Expand Down
6 changes: 3 additions & 3 deletions src/sageworks/themes/dark/dark.json
Original file line number Diff line number Diff line change
Expand Up @@ -564,10 +564,10 @@
"diverging": [
[0.0, "rgba(80, 80, 240, 1.0)"],
[0.25, "rgba(70, 200, 200, 0.5)"],
[0.35, "rgba(150, 150, 150, 0.0)"],
[0.65, "rgba(150, 150, 150, 0.0)"],
[0.30, "rgba(70, 200, 200, 0.0)"],
[0.70, "rgba(200, 200, 100, 0.0)"],
[0.75, "rgba(200, 200, 100, 0.5)"],
[1.0, "rgba(255, 60, 80, 1.0)"]
[1.0, "rgba(120, 30, 30, 1.0)"]
],
"sequential": [
[0.0, "rgba(100, 100, 255, 1.0)"],
Expand Down
16 changes: 9 additions & 7 deletions src/sageworks/web_interface/components/correlation_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@ def update_properties(self, data_source_details: dict) -> go.Figure:
if "column_stats" not in data_source_details:
return self.display_text("No column_stats Found", figure_height=200)

# Threshold for the correlation matrix (filter out low values)
threshold = 0.4

# Convert the data details into a correlation dataframe
df = self._corr_df_from_data_details(data_source_details)
df = self._corr_df_from_data_details(data_source_details, threshold=threshold)

# If the dataframe is empty then return a message
if df.empty:
Expand Down Expand Up @@ -83,17 +86,16 @@ def update_properties(self, data_source_details: dict) -> go.Figure:
fig.update_yaxes(tickvals=y_labels, ticktext=df.index, showgrid=False)

# Now we're going to customize the annotations and filter out low values
label_threshold = 0.3
for i, row in enumerate(df.index):
for j, col in enumerate(df.columns):
value = df.loc[row, col]
if abs(value) > label_threshold:
if abs(value) > threshold:
fig.add_annotation(x=j, y=i, text=f"{value:.2f}", showarrow=False)

return fig

@staticmethod
def _corr_df_from_data_details(data_details: dict, threshold: float = 0.3) -> pd.DataFrame:
def _corr_df_from_data_details(data_details: dict, threshold: float = 0.4) -> pd.DataFrame:
"""Internal: Create a Pandas DataFrame in the form given by df.corr() from DataSource details.
Args:
data_details (dict): A dictionary containing DataSource details.
Expand All @@ -118,12 +120,12 @@ def _corr_df_from_data_details(data_details: dict, threshold: float = 0.3) -> pd
corr_df = corr_df.loc[:, (corr_df.abs().max() > threshold)]
corr_df = corr_df[(corr_df.abs().max(axis=1) > threshold)]

# If the correlation matrix is bigger than 10x10 then we need to filter it down
while corr_df.shape[0] > 10 and threshold <= 0.6:
# If the correlation matrix is bigger than 12x12 then we need to filter it down
while corr_df.shape[0] > 12:
# Now filter out any correlations below the threshold
corr_df = corr_df.loc[:, (corr_df.abs().max() > threshold)]
corr_df = corr_df[(corr_df.abs().max(axis=1) > threshold)]
threshold += 0.1
threshold += 0.05

# Return the correlation dataframe in the form of df.corr()
corr_df.sort_index(inplace=True)
Expand Down

0 comments on commit e8074d5

Please sign in to comment.