Skip to content

Commit

Permalink
fix: fixing running kruskal wallis and plotting sig syllables
Browse files Browse the repository at this point in the history
fix: fixing running kruskal wallis and plotting sig syllables
  • Loading branch information
versey-sherry authored Mar 15, 2024
2 parents d4a0ea3 + ebbd2fb commit c28ef80
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions keypoint_moseq/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def get_syllable_names(project_dir, model_name, syllable_ixs):

for ix in syllable_ixs:
if len(syll_info_df[syll_info_df.syllable == ix].label.values[0]) > 0:
labels[
ix
] = f"{ix} ({syll_info_df[syll_info_df.syllable == ix].label.values[0]})"
labels[ix] = (
f"{ix} ({syll_info_df[syll_info_df.syllable == ix].label.values[0]})"
)
names = [labels[ix] for ix in syllable_ixs]
return names

Expand Down Expand Up @@ -880,7 +880,10 @@ def run_kruskal(
# combine Dunn's test results into single DataFrame
df_z = pd.DataFrame(real_zs_within_group)
df_z.index = df_z.index.set_names("syllable")
dunn_results_df = df_z.reset_index().melt(id_vars="syllable")
dunn_results_df = df_z.reset_index().melt(id_vars=[("syllable", "")])
dunn_results_df.rename(
columns={"variable_0": "group1", "variable_1": "group2"}, inplace=True
)

# Get intersecting significant syllables between
intersect_sig_syllables = {}
Expand Down Expand Up @@ -1036,7 +1039,7 @@ def _validate_and_order_syll_stats_params(
if len(colors) == 0 or len(colors) != len(groups):
colors = sns.color_palette(n_colors=len(groups))

return ordering, groups, colors, figsize
return np.array(ordering), groups, colors, figsize


def save_analysis_figure(fig, plot_name, project_dir, model_name, save_dir):
Expand Down Expand Up @@ -1088,7 +1091,7 @@ def plot_syll_stats_with_sem(
the threshold for significance, by default 0.05
stat : str, optional
the statistic to plot, by default 'frequency'
ordering : str, optional
order : str, optional
the ordering of the syllables, by default 'stat'
groups : list, optional
the list of groups to plot, by default None
Expand Down Expand Up @@ -1171,8 +1174,12 @@ def plot_syll_stats_with_sem(
if sig_sylls is not None:
markings = []
for s in sig_sylls:
markings.append(ordering.index(s))
plt.scatter(markings, [-0.005] * len(markings), color="r", marker="*")
if s in ordering:
markings.append(np.where(ordering == s)[0])
else:
continue
markings = np.concatenate(markings)
plt.scatter(markings, [-0.05] * len(markings), color="r", marker="*")

# manually define a new patch
patch = Line2D(
Expand Down

0 comments on commit c28ef80

Please sign in to comment.