Changes structure and adds metrics

This commit is contained in:
Daniel Svitan 2024-12-22 18:30:29 +01:00
parent e8bacc49c0
commit c970d13b43

View File

@ -2,6 +2,7 @@ import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import random_split, DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
@ -10,35 +11,31 @@ class NeuralNetwork(nn.Module):
super().__init__()
self.features = nn.Sequential(
nn.Linear(8, 256),
nn.Linear(8, 512),
nn.ReLU(),
nn.Dropout(p=0.2),
nn.BatchNorm1d(512),
nn.Dropout(p=0.4),
nn.Linear(256, 256),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(p=0.2),
nn.BatchNorm1d(256),
nn.Dropout(p=0.4),
nn.Linear(256, 256),
nn.Linear(256, 128),
nn.ReLU(),
nn.Dropout(p=0.2),
nn.BatchNorm1d(128),
nn.Dropout(p=0.4),
nn.Linear(256, 256),
nn.Linear(128, 64),
nn.ReLU(),
nn.Dropout(p=0.2),
nn.BatchNorm1d(64),
nn.Dropout(p=0.4),
nn.Linear(256, 4),
nn.ReLU()
nn.Linear(64, 4)
)
self.lstm = nn.LSTM(input_size=4, hidden_size=64, batch_first=True)
self.output_layer = nn.Linear(64, 4)
def forward(self, x):
x = self.features(x)
x = x.unsqueeze(1)
lstm_out, _ = self.lstm(x)
x = lstm_out.squeeze(1)
return self.output_layer(x)
return self.features(x)
data = np.load("clean.npy")
@ -54,7 +51,7 @@ test_size = len(data) - train_size
torch.manual_seed(42)
batch_size = 32
epochs = 100
epochs = 200
lr = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -76,7 +73,8 @@ print(model)
print("")
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)
for epoch in range(epochs):
print(f"Epoch {epoch + 1}/{epochs}...\r", end="")
@ -117,6 +115,7 @@ for epoch in range(epochs):
f"Train loss: {train_loss:.4f}\n"
f"Test loss: {test_loss:.4f}\n"
)
scheduler.step(test_loss)
torch.save(model.state_dict(), "model.pth")
print("Model saved to model.pth")