✨ Adds graphing to train
This commit is contained in:
parent
c970d13b43
commit
b3519f21b0
3
.gitignore
vendored
3
.gitignore
vendored
@ -5,10 +5,11 @@
|
||||
venv/
|
||||
__pycache__/
|
||||
|
||||
results/
|
||||
|
||||
*.zip
|
||||
*.csv
|
||||
*.npy
|
||||
|
||||
*.jasp
|
||||
results.txt
|
||||
*.pth
|
||||
|
36
train_nn.py
36
train_nn.py
@ -1,9 +1,18 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import argparse
|
||||
from torch.utils.data import random_split, DataLoader, TensorDataset
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="train_nn"
|
||||
)
|
||||
parser.add_argument("-g", "--graph", action="store_true", default=False, help="Graph losses")
|
||||
args = parser.parse_args()
|
||||
graph = args.graph
|
||||
|
||||
|
||||
class NeuralNetwork(nn.Module):
|
||||
@ -76,6 +85,9 @@ loss_fn = nn.MSELoss()
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
|
||||
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)
|
||||
|
||||
train_losses = []
|
||||
test_losses = []
|
||||
|
||||
for epoch in range(epochs):
|
||||
print(f"Epoch {epoch + 1}/{epochs}...\r", end="")
|
||||
|
||||
@ -95,6 +107,7 @@ for epoch in range(epochs):
|
||||
train_loss += loss.item() * X.size(0)
|
||||
|
||||
train_loss /= len(train_dataset)
|
||||
train_losses.append(train_loss)
|
||||
|
||||
model.eval()
|
||||
test_loss = 0.0
|
||||
@ -109,6 +122,7 @@ for epoch in range(epochs):
|
||||
test_loss = loss.item() * X.size(0)
|
||||
|
||||
test_loss /= len(test_dataset)
|
||||
test_losses.append(test_loss)
|
||||
|
||||
print(
|
||||
f"Epoch {epoch + 1}/{epochs}\n"
|
||||
@ -117,6 +131,9 @@ for epoch in range(epochs):
|
||||
)
|
||||
scheduler.step(test_loss)
|
||||
|
||||
print(f"Average train loss: {sum(train_losses) / len(train_losses):.4f}")
|
||||
print(f"Average test loss: {sum(test_losses) / len(test_losses):.4f}")
|
||||
|
||||
torch.save(model.state_dict(), "model.pth")
|
||||
print("Model saved to model.pth")
|
||||
|
||||
@ -144,3 +161,22 @@ print(f"Accuracy: {accuracy:.4f}")
|
||||
print(f"Precision: {precision:.4f}")
|
||||
print(f"Recall: {recall:.4f}")
|
||||
print(f"F1 Score: {f1:.4f}")
|
||||
|
||||
if graph:
|
||||
x = np.arange(1, epochs + 1, 1)
|
||||
|
||||
plt.figure(figsize=(8, 6))
|
||||
plt.plot(x, train_losses, color="red", label="Strata trénovania")
|
||||
plt.plot(x, test_losses, color="blue", label="Strata testovania")
|
||||
|
||||
plt.xlabel("Epocha")
|
||||
plt.ylabel("Strata")
|
||||
plt.title("Priebeh trénovania")
|
||||
|
||||
plt.text(0.99, 0.99,
|
||||
f"Presnosť: {accuracy:.4f}\nPrecíznosť: {precision:.4f}\nOdvolanie: {recall:.4f}\nF1 skóre: {f1:.4f}",
|
||||
ha="right", va="top", transform=plt.gca().transAxes, fontweight="bold")
|
||||
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
Loading…
x
Reference in New Issue
Block a user