✨ Adds graphing to train
This commit is contained in:
parent
c970d13b43
commit
b3519f21b0
3
.gitignore
vendored
3
.gitignore
vendored
@ -5,10 +5,11 @@
|
|||||||
venv/
|
venv/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
|
||||||
|
results/
|
||||||
|
|
||||||
*.zip
|
*.zip
|
||||||
*.csv
|
*.csv
|
||||||
*.npy
|
*.npy
|
||||||
|
|
||||||
*.jasp
|
*.jasp
|
||||||
results.txt
|
|
||||||
*.pth
|
*.pth
|
||||||
|
36
train_nn.py
36
train_nn.py
@ -1,9 +1,18 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import argparse
|
||||||
from torch.utils.data import random_split, DataLoader, TensorDataset
|
from torch.utils.data import random_split, DataLoader, TensorDataset
|
||||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="train_nn"
|
||||||
|
)
|
||||||
|
parser.add_argument("-g", "--graph", action="store_true", default=False, help="Graph losses")
|
||||||
|
args = parser.parse_args()
|
||||||
|
graph = args.graph
|
||||||
|
|
||||||
|
|
||||||
class NeuralNetwork(nn.Module):
|
class NeuralNetwork(nn.Module):
|
||||||
@ -76,6 +85,9 @@ loss_fn = nn.MSELoss()
|
|||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
|
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
|
||||||
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)
|
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)
|
||||||
|
|
||||||
|
train_losses = []
|
||||||
|
test_losses = []
|
||||||
|
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
print(f"Epoch {epoch + 1}/{epochs}...\r", end="")
|
print(f"Epoch {epoch + 1}/{epochs}...\r", end="")
|
||||||
|
|
||||||
@ -95,6 +107,7 @@ for epoch in range(epochs):
|
|||||||
train_loss += loss.item() * X.size(0)
|
train_loss += loss.item() * X.size(0)
|
||||||
|
|
||||||
train_loss /= len(train_dataset)
|
train_loss /= len(train_dataset)
|
||||||
|
train_losses.append(train_loss)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
test_loss = 0.0
|
test_loss = 0.0
|
||||||
@ -109,6 +122,7 @@ for epoch in range(epochs):
|
|||||||
test_loss = loss.item() * X.size(0)
|
test_loss = loss.item() * X.size(0)
|
||||||
|
|
||||||
test_loss /= len(test_dataset)
|
test_loss /= len(test_dataset)
|
||||||
|
test_losses.append(test_loss)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Epoch {epoch + 1}/{epochs}\n"
|
f"Epoch {epoch + 1}/{epochs}\n"
|
||||||
@ -117,6 +131,9 @@ for epoch in range(epochs):
|
|||||||
)
|
)
|
||||||
scheduler.step(test_loss)
|
scheduler.step(test_loss)
|
||||||
|
|
||||||
|
print(f"Average train loss: {sum(train_losses) / len(train_losses):.4f}")
|
||||||
|
print(f"Average test loss: {sum(test_losses) / len(test_losses):.4f}")
|
||||||
|
|
||||||
torch.save(model.state_dict(), "model.pth")
|
torch.save(model.state_dict(), "model.pth")
|
||||||
print("Model saved to model.pth")
|
print("Model saved to model.pth")
|
||||||
|
|
||||||
@ -144,3 +161,22 @@ print(f"Accuracy: {accuracy:.4f}")
|
|||||||
print(f"Precision: {precision:.4f}")
|
print(f"Precision: {precision:.4f}")
|
||||||
print(f"Recall: {recall:.4f}")
|
print(f"Recall: {recall:.4f}")
|
||||||
print(f"F1 Score: {f1:.4f}")
|
print(f"F1 Score: {f1:.4f}")
|
||||||
|
|
||||||
|
if graph:
|
||||||
|
x = np.arange(1, epochs + 1, 1)
|
||||||
|
|
||||||
|
plt.figure(figsize=(8, 6))
|
||||||
|
plt.plot(x, train_losses, color="red", label="Strata trénovania")
|
||||||
|
plt.plot(x, test_losses, color="blue", label="Strata testovania")
|
||||||
|
|
||||||
|
plt.xlabel("Epocha")
|
||||||
|
plt.ylabel("Strata")
|
||||||
|
plt.title("Priebeh trénovania")
|
||||||
|
|
||||||
|
plt.text(0.99, 0.99,
|
||||||
|
f"Presnosť: {accuracy:.4f}\nPrecíznosť: {precision:.4f}\nOdvolanie: {recall:.4f}\nF1 skóre: {f1:.4f}",
|
||||||
|
ha="right", va="top", transform=plt.gca().transAxes, fontweight="bold")
|
||||||
|
|
||||||
|
plt.legend()
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user