diff --git a/analyze.py b/analyze.py
index 8191c4d..cde16f0 100644
--- a/analyze.py
+++ b/analyze.py
@@ -7,21 +7,46 @@ import matplotlib.pyplot as plt
 
 parser = argparse.ArgumentParser()
 parser.add_argument("-g", "--graph", action="store_true", default=False, help="Plot graph")
+parser.add_argument("-s", "--save", default="", help="Graph save location")
 args = parser.parse_args()
 graph = args.graph
+save = args.save
 
 
 def analyze(name: str, data: List[np.ndarray]):
-    F, p = stats.f_oneway(*data)
+    """Compare the groups in ``data`` with a Kruskal-Wallis H-test.
+
+    Groups holding 5 or fewer observations are reported and left out. When
+    the test is significant at the 0.05 level, a Tukey HSD comparison of the
+    remaining groups is printed as well.
+
+    Returns the (H statistic, p-value) pair, or (nan, nan) when fewer than
+    two groups are left to compare.
+    """
+    filtered_data = []
+    for index, item in enumerate(data):
+        if len(item) > 5:
+            filtered_data.append(item)
+        else:
+            print(f"Data group at index {index} removed due to insufficient size ({len(item)})")
+
+    # stats.kruskal needs at least two groups; bail out instead of crashing.
+    if len(filtered_data) < 2:
+        print(f"Fewer than two data groups left for {name}; nothing to compare\n")
+        return float("nan"), float("nan")
+
+    F, p = stats.kruskal(*filtered_data)
     print(f"F-stats for {name}: {F}")
     print(f"p-value for {name}: {p}")
 
     if p > 0.05:
         print("statistically insignificant\n")
         return F, p
 
     print("statistically significant")
-    tukey_results = stats.tukey_hsd(*data)
+    # NOTE(review): Tukey HSD assumes normally distributed groups, unlike the
+    # rank-based test above; a rank-based post-hoc test may fit better.
+    tukey_results = stats.tukey_hsd(*filtered_data)
     print(tukey_results)
 
     return F, p
@@ -43,7 +68,7 @@ def plot_violin(data, labels, Fs, ps, title):
             index = j * 2 + k
             step = 1 if index > 0 else 0.5
 
-            axs[j, k].violinplot(data[index], showmeans=True)
+            axs[j, k].violinplot(data[index], showmedians=True)
             axs[j, k].set_title(grade_names[index])
             axs[j, k].set_xlabel(title, fontweight="bold")
             axs[j, k].set_ylabel(grade_name_labels[index], fontweight="bold")
@@ -55,11 +80,13 @@ def plot_violin(data, labels, Fs, ps, title):
             axs[j, k].text(0.01, 0.99, f"F-stat: {F:.2f}\np-val: {p:.4f}", ha="left", va="top",
                            transform=axs[j, k].transAxes, fontweight="bold")
 
-            means = list([a.mean() for a in data[index]])
-            for l in range(len(means)):
-                mean = round(means[l], 2)
-                axs[j, k].text(l + 1.05, mean + 0.05, f"{mean}")
+            for slot, samples in enumerate(data[index]):
+                median = round(np.median(samples), 2)
+                axs[j, k].text(slot + 1.05, median + 0.05, f"{median}")
 
     fig.tight_layout()
     fig.show()
-    plt.show()
+    if save:
+        fig.savefig(save)
+    else:
+        plt.show()