From b3519f21b04d5e7b3d00e2abea0c55176784fb97 Mon Sep 17 00:00:00 2001 From: Daniel Svitan Date: Sun, 22 Dec 2024 19:23:08 +0100 Subject: [PATCH] :sparkles: Adds graphing to train --- .gitignore | 3 ++- train_nn.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 63a85b8..abfed3a 100644 --- a/.gitignore +++ b/.gitignore @@ -5,10 +5,11 @@ venv/ __pycache__/ +results/ + *.zip *.csv *.npy *.jasp -results.txt *.pth diff --git a/train_nn.py b/train_nn.py index 182f16a..4f23d34 100644 --- a/train_nn.py +++ b/train_nn.py @@ -1,9 +1,18 @@ import numpy as np import torch import torch.nn as nn +import argparse from torch.utils.data import random_split, DataLoader, TensorDataset from torch.optim.lr_scheduler import ReduceLROnPlateau from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score +import matplotlib.pyplot as plt + +parser = argparse.ArgumentParser( + prog="train_nn" +) +parser.add_argument("-g", "--graph", action="store_true", default=False, help="Graph losses") +args = parser.parse_args() +graph = args.graph class NeuralNetwork(nn.Module): @@ -76,6 +85,9 @@ loss_fn = nn.MSELoss() optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5) +train_losses = [] +test_losses = [] + for epoch in range(epochs): print(f"Epoch {epoch + 1}/{epochs}...\r", end="") @@ -95,6 +107,7 @@ for epoch in range(epochs): train_loss += loss.item() * X.size(0) train_loss /= len(train_dataset) + train_losses.append(train_loss) model.eval() test_loss = 0.0 @@ -109,6 +122,7 @@ for epoch in range(epochs): test_loss = loss.item() * X.size(0) test_loss /= len(test_dataset) + test_losses.append(test_loss) print( f"Epoch {epoch + 1}/{epochs}\n" @@ -117,6 +131,9 @@ for epoch in range(epochs): ) scheduler.step(test_loss) +print(f"Average train loss: {sum(train_losses) / len(train_losses):.4f}") +print(f"Average test loss: {sum(test_losses) / len(test_losses):.4f}") + torch.save(model.state_dict(), "model.pth") print("Model saved to model.pth") @@ -144,3 +161,22 @@ print(f"Accuracy: {accuracy:.4f}") print(f"Precision: {precision:.4f}") print(f"Recall: {recall:.4f}") print(f"F1 Score: {f1:.4f}") + +if graph: + x = np.arange(1, epochs + 1, 1) + + plt.figure(figsize=(8, 6)) + plt.plot(x, train_losses, color="red", label="Strata trénovania") + plt.plot(x, test_losses, color="blue", label="Strata testovania") + + plt.xlabel("Epocha") + plt.ylabel("Strata") + plt.title("Priebeh trénovania") + + plt.text(0.99, 0.99, + f"Presnosť: {accuracy:.4f}\nPrecíznosť: {precision:.4f}\nOdvolanie: {recall:.4f}\nF1 skóre: {f1:.4f}", + ha="right", va="top", transform=plt.gca().transAxes, fontweight="bold") + + plt.legend() + plt.tight_layout() + plt.show()