diff --git a/train_nn.py b/train_nn.py index 836aadc..182f16a 100644 --- a/train_nn.py +++ b/train_nn.py @@ -2,6 +2,7 @@ import numpy as np import torch import torch.nn as nn from torch.utils.data import random_split, DataLoader, TensorDataset +from torch.optim.lr_scheduler import ReduceLROnPlateau from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score @@ -10,35 +11,31 @@ class NeuralNetwork(nn.Module): super().__init__() self.features = nn.Sequential( - nn.Linear(8, 256), + nn.Linear(8, 512), nn.ReLU(), - nn.Dropout(p=0.2), + nn.BatchNorm1d(512), + nn.Dropout(p=0.4), - nn.Linear(256, 256), + nn.Linear(512, 256), nn.ReLU(), - nn.Dropout(p=0.2), + nn.BatchNorm1d(256), + nn.Dropout(p=0.4), - nn.Linear(256, 256), + nn.Linear(256, 128), nn.ReLU(), - nn.Dropout(p=0.2), + nn.BatchNorm1d(128), + nn.Dropout(p=0.4), - nn.Linear(256, 256), + nn.Linear(128, 64), nn.ReLU(), - nn.Dropout(p=0.2), + nn.BatchNorm1d(64), + nn.Dropout(p=0.4), - nn.Linear(256, 4), - nn.ReLU() + nn.Linear(64, 4) ) - self.lstm = nn.LSTM(input_size=4, hidden_size=64, batch_first=True) - self.output_layer = nn.Linear(64, 4) - def forward(self, x): - x = self.features(x) - x = x.unsqueeze(1) - lstm_out, _ = self.lstm(x) - x = lstm_out.squeeze(1) - return self.output_layer(x) + return self.features(x) data = np.load("clean.npy") @@ -54,7 +51,7 @@ test_size = len(data) - train_size torch.manual_seed(42) batch_size = 32 -epochs = 100 +epochs = 200 lr = 0.001 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -76,7 +73,8 @@ print(model) print("") loss_fn = nn.MSELoss() -optimizer = torch.optim.Adam(model.parameters(), lr=lr) +optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) +scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5) for epoch in range(epochs): print(f"Epoch {epoch + 1}/{epochs}...\r", end="") @@ -117,6 +115,7 @@ for epoch in range(epochs): f"Train loss: {train_loss:.4f}\n" f"Test loss: {test_loss:.4f}\n" ) + scheduler.step(test_loss) torch.save(model.state_dict(), "model.pth") print("Model saved to model.pth")