diff --git a/train_nn.py b/train_nn.py index 4f23d34..5625a03 100644 --- a/train_nn.py +++ b/train_nn.py @@ -119,7 +119,7 @@ for epoch in range(epochs): pred = model(X) loss = loss_fn(pred, y) - test_loss = loss.item() * X.size(0) + test_loss += loss.item() * X.size(0) test_loss /= len(test_dataset) test_losses.append(test_loss)