Adds printing group differences

This commit is contained in:
Daniel Svitan 2024-12-27 11:48:39 +01:00
parent 6831e847ff
commit 3ad7babcdc
2 changed files with 79 additions and 21 deletions

View File

@ -1,7 +1,9 @@
from typing import List from typing import List
import itertools
import argparse import argparse
import numpy as np import numpy as np
import pandas as pd
import scipy.stats as stats import scipy.stats as stats
import scikit_posthocs as sp import scikit_posthocs as sp
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -14,35 +16,83 @@ graph = args.graph
save = args.save save = args.save
# source: mostly ChatGPT-generated; verify the statistics before relying on them
def analyze(name: str, data: List[np.ndarray]):
    """Run a Kruskal-Wallis test on ``data`` and print pairwise group differences.

    Each group is reduced to its usable numeric observations (NaN and
    non-numeric entries are dropped). Groups with five or fewer usable
    observations are removed and reported. If the omnibus test is significant
    at the 0.05 level, a Bonferroni-adjusted Dunn post-hoc test is run and a
    summary table (effect size, mean difference, median difference, post-hoc
    p-value) is printed for every pair of remaining groups.

    Args:
        name: Label used in the printed output.
        data: One array of observations per group. Groups are labelled
            "A", "B", ... by their position in this list.

    Returns:
        ``(F, p)`` as returned by ``scipy.stats.kruskal`` (the statistic is the
        Kruskal-Wallis H value), or ``(None, None)`` when fewer than two groups
        remain after filtering.
    """
    filtered_data = []
    group_names = []

    for index, item in enumerate(data):
        # np.integer / np.floating are listed explicitly: NumPy integer scalars
        # are not instances of the built-in ``int``, so checking only
        # ``(int, float)`` would silently discard every value of an integer
        # array. NaN is excluded because it is not a usable observation and
        # would turn the test statistic into NaN.
        numeric_data = [
            x for x in item
            if isinstance(x, (int, float, np.integer, np.floating)) and not np.isnan(x)
        ]
        if len(numeric_data) > 5:
            filtered_data.append(numeric_data)
            group_names.append(chr(65 + index))
        else:
            print(f"Data group at index {index} removed due to insufficient size ({len(numeric_data)})")

    if len(filtered_data) < 2:
        print(f"Insufficient number of groups for Kruskal-Wallis test in {name}")
        return None, None

    # Kruskal-Wallis Test
    F, p = stats.kruskal(*filtered_data)
    print(f"\nF-stats for {name}: {F:.8f}")
    print(f"p-value for {name}: {p:.8f}")

    # A NaN p-value (undefined test) must not fall through to the
    # "significant" branch, which a bare ``p > 0.05`` comparison would allow.
    if np.isnan(p) or p > 0.05:
        print("statistically insignificant\n")
        return F, p

    print("statistically significant")

    # Rank all observations together, then split the ranks back per group.
    all_values = np.concatenate([np.asarray(group, dtype=float) for group in filtered_data])
    total_sample_size = len(all_values)
    all_ranks = stats.rankdata(all_values)
    boundaries = np.cumsum([len(group) for group in filtered_data])[:-1]
    group_ranks = np.split(all_ranks, boundaries)

    # Variance of the pooled ranks with the usual tie correction, as used in
    # the standard error of Dunn's z statistic.
    _, tie_counts = np.unique(all_values, return_counts=True)
    tie_correction = np.sum(tie_counts ** 3 - tie_counts) / (12.0 * (total_sample_size - 1))
    rank_variance = total_sample_size * (total_sample_size + 1) / 12.0 - tie_correction

    # Post-Hoc Dunn Test (Bonferroni-adjusted p-values)
    posthoc_results = sp.posthoc_dunn(filtered_data, p_adjust='bonferroni')

    results = []
    for idx1, idx2 in itertools.combinations(range(len(group_names)), 2):
        n1 = len(filtered_data[idx1])
        n2 = len(filtered_data[idx2])

        rank_diff = np.mean(group_ranks[idx1]) - np.mean(group_ranks[idx2])

        # Dunn's z: mean-rank difference over its standard error.
        z_stat = rank_diff / np.sqrt(rank_variance * (1.0 / n1 + 1.0 / n2))
        # Effect size r = z / sqrt(N). This is not a rank-biserial correlation.
        effect_size = z_stat / np.sqrt(total_sample_size)

        mean_diff = np.mean(filtered_data[idx1]) - np.mean(filtered_data[idx2])
        median_diff = np.median(filtered_data[idx1]) - np.median(filtered_data[idx2])

        # NOTE(review): relies on posthoc_dunn labelling list inputs 1..k in
        # input order -- confirm against the installed scikit-posthocs version.
        p_value = posthoc_results.loc[idx1 + 1, idx2 + 1]

        results.append({
            "Group 1": group_names[idx1],
            "Group 2": group_names[idx2],
            "Effect Size": f"{effect_size:.4f}",
            "Mean Difference": f"{mean_diff:.4f}",
            "Median Difference": f"{median_diff:.4f}",
            "Post-Hoc p-value": f"{p_value:.4f}"
        })

    results_df = pd.DataFrame(results, dtype="object")
    print("\nSummary Table of Effect Size, Mean, and Median Differences:")
    print(results_df.to_markdown(index=False, tablefmt="github", disable_numparse=True))
    print("")

    return F, p
@ -63,7 +113,7 @@ def plot_violin(data, labels, Fs, ps, title):
index = j * 2 + k index = j * 2 + k
step = 1 if index > 0 else 0.5 step = 1 if index > 0 else 0.5
axs[j, k].violinplot(data[index], showmedians=True) axs[j, k].violinplot(data[index], showmedians=True, showmeans=True)
axs[j, k].set_title(grade_names[index]) axs[j, k].set_title(grade_names[index])
axs[j, k].set_xlabel(title, fontweight="bold") axs[j, k].set_xlabel(title, fontweight="bold")
axs[j, k].set_ylabel(grade_name_labels[index], fontweight="bold") axs[j, k].set_ylabel(grade_name_labels[index], fontweight="bold")
@ -72,13 +122,18 @@ def plot_violin(data, labels, Fs, ps, title):
F = round(Fs[index], 2) F = round(Fs[index], 2)
p = round(ps[index], 4) p = round(ps[index], 4)
axs[j, k].text(0.01, 0.99, f"F-stat: {F:.2f}\np-val: {p:.4f}", ha="left", va="top", transform=axs[j, k].transAxes, axs[j, k].text(0.01, 0.99, f"F-stat: {F:.2f}\np-val: {p:.4f}", ha="left", va="top",
transform=axs[j, k].transAxes,
fontweight="bold") fontweight="bold")
medians = list([np.median(a) for a in data[index]]) medians = list([np.median(a) for a in data[index]])
means = list([a.mean() for a in data[index]])
for l in range(len(medians)): for l in range(len(medians)):
median = round(medians[l], 2) median = round(medians[l], 2)
axs[j, k].text(l + 1.05, median + 0.05, f"{median}") mean = round(means[l], 2)
# left - mean, right - median
axs[j, k].text(l + 1.13, median - 0.05, f"{median}")
axs[j, k].text(l + 0.77, mean - 0.05, f"{mean}")
fig.tight_layout() fig.tight_layout()
if save != "": if save != "":

View File

@ -25,6 +25,7 @@ nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127 nvidia-nvtx-cu12==12.4.127
packaging==24.2 packaging==24.2
pandas==2.2.3 pandas==2.2.3
pandas-flavor==0.6.0
patsy==1.0.1 patsy==1.0.1
pillow==11.0.0 pillow==11.0.0
pyparsing==3.2.0 pyparsing==3.2.0
@ -38,7 +39,9 @@ setuptools==75.6.0
six==1.17.0 six==1.17.0
statsmodels==0.14.4 statsmodels==0.14.4
sympy==1.13.1 sympy==1.13.1
tabulate==0.9.0
threadpoolctl==3.5.0 threadpoolctl==3.5.0
torch==2.5.1 torch==2.5.1
typing_extensions==4.12.2 typing_extensions==4.12.2
tzdata==2024.2 tzdata==2024.2
xarray==2024.11.0