Adds graphing to train

This commit is contained in:
Daniel Svitan 2024-12-22 19:23:08 +01:00
parent c970d13b43
commit b3519f21b0
2 changed files with 38 additions and 1 deletions

3
.gitignore vendored
View File

@ -5,10 +5,11 @@
venv/
__pycache__/
results/
*.zip
*.csv
*.npy
*.jasp
results.txt
*.pth

View File

@ -1,9 +1,18 @@
import numpy as np
import torch
import torch.nn as nn
import argparse
from torch.utils.data import random_split, DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
parser = argparse.ArgumentParser(
prog="train_nn"
)
parser.add_argument("-g", "--graph", action="store_true", default=False, help="Graph losses")
args = parser.parse_args()
graph = args.graph
class NeuralNetwork(nn.Module):
@ -76,6 +85,9 @@ loss_fn = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)
train_losses = []
test_losses = []
for epoch in range(epochs):
print(f"Epoch {epoch + 1}/{epochs}...\r", end="")
@ -95,6 +107,7 @@ for epoch in range(epochs):
train_loss += loss.item() * X.size(0)
train_loss /= len(train_dataset)
train_losses.append(train_loss)
model.eval()
test_loss = 0.0
@ -109,6 +122,7 @@ for epoch in range(epochs):
test_loss = loss.item() * X.size(0)
test_loss /= len(test_dataset)
test_losses.append(test_loss)
print(
f"Epoch {epoch + 1}/{epochs}\n"
@ -117,6 +131,9 @@ for epoch in range(epochs):
)
scheduler.step(test_loss)
print(f"Average train loss: {sum(train_losses) / len(train_losses):.4f}")
print(f"Average test loss: {sum(test_losses) / len(test_losses):.4f}")
torch.save(model.state_dict(), "model.pth")
print("Model saved to model.pth")
@ -144,3 +161,22 @@ print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
if graph:
x = np.arange(1, epochs + 1, 1)
plt.figure(figsize=(8, 6))
plt.plot(x, train_losses, color="red", label="Strata trénovania")
plt.plot(x, test_losses, color="blue", label="Strata testovania")
plt.xlabel("Epocha")
plt.ylabel("Strata")
plt.title("Priebeh trénovania")
plt.text(0.99, 0.99,
f"Presnosť: {accuracy:.4f}\nPrecíznosť: {precision:.4f}\nOdvolanie: {recall:.4f}\nF1 skóre: {f1:.4f}",
ha="right", va="top", transform=plt.gca().transAxes, fontweight="bold")
plt.legend()
plt.tight_layout()
plt.show()