diff --git a/mlp_train.py b/mlp_train.py
index 71c0250af2bc429f4978226ca410110060a5ec83..6980e9d4b05c8421b47f15aa12817b8acd0df3aa 100644
--- a/mlp_train.py
+++ b/mlp_train.py
@@ -4,32 +4,66 @@ import numpy as np
 import torch
 from tqdm import trange, tqdm
 from sklearn.model_selection import train_test_split
-import time
 import matplotlib.pyplot as plt
 from sklearn.metrics import confusion_matrix
 import seaborn as sns
+from sklearn.metrics import roc_auc_score
+
+
+# Column-name groups: salinity (S) features, temperature (T) features, and the
+# remaining (B) features.
+
+S_features_filter = [
+    "abs_S_Smin",
+    "rel_S_Smin_semi_width",
+    "rel_S_Smin_full_width",
+    "abs_S_Smax",
+    "rel_S_Smax_semi_width",
+    "rel_S_Smax_full_width",
+    "count_anomalies_S",
+    "ratio_anomalies_S",
+    "max_variation_S",
+]
+T_features_filter = [
+    "abs_T_Tmin",
+    "rel_T_Tmin_semi_width",
+    "rel_T_Tmin_full_width",
+    "abs_T_Tmax",
+    "rel_T_Tmax_semi_width",
+    "rel_T_Tmax_full_width",
+    "count_anomalies_T",
+    "ratio_anomalies_T",
+    "max_variation_T",
+]
+B_features_filter = ["mean_correlation", "nb_measurements"]
 
 PICKLE_PATH = "dataset_pandas/temperature.pkl"
 
 
 ##### HYPERPARAMETERS #####
 EPOCHS = 300
-BATCH_SIZE = 16
+BATCH_SIZE = 32
 CRITERION = nn.BCELoss()
 OPTIMIZER = torch.optim.Adam
 LEARNING_RATE = 0.01
 GROWTH_RATE = 16
 DROP_RATE = 0.5
-SCHEDULER_PATIENCE = 20
+SCHEDULER_PATIENCE = 10
 SCHEDULER_FACTOR = 0.5
 SCHEDULER_EPS = 1e-8
 
 
+# 11 columns remain after prepare_data() drops the 9 salinity features
+INPUT_FEATURES = 11
+
+
 class MLP(nn.Module):
     def __init__(self):
         super(MLP, self).__init__()
         self.block1 = nn.Sequential(
-            nn.Linear(20, GROWTH_RATE), nn.BatchNorm1d(GROWTH_RATE), nn.ReLU()
+            nn.Linear(INPUT_FEATURES, GROWTH_RATE),
+            nn.BatchNorm1d(GROWTH_RATE),
+            nn.ReLU(),
         )
         self.block2 = nn.Sequential(
             nn.Linear(GROWTH_RATE, GROWTH_RATE), nn.BatchNorm1d(GROWTH_RATE), nn.ReLU()
@@ -68,25 +102,42 @@ class MLP(nn.Module):
         return y
 
 
-def prepare_data() -> tuple[np.ndarray, np.ndarray]:
+def prepare_data() -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
     # Load data
     df = pd.read_pickle(PICKLE_PATH)
 
-    X = df.drop(columns=["alarm"]).to_numpy()
-    y = df["alarm"].to_numpy()
+    # drop Salinity Features
+    df = df.drop(columns=S_features_filter)
 
-    assert X.shape[1] == 20, "Number of features should be 20"
-    assert y.shape[0] == X.shape[0], "Number of labels should match number of samples"
+    # split the data into training and testing sets
+    train_df, test_df = train_test_split(df, test_size=0.2, random_state=123456789)
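+    # (the split is unstratified; passing stratify=df["alarm"] would preserve
+    # the class ratio in the test set if that is desired)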
 
-    return X, y
+    print(
+        f"Train alarm distribution (before undersampling): {train_df['alarm'].value_counts()}"
+    )
 
+    # balance the training set by undersampling the majority class;
+    # the fixed seed keeps the sampling reproducible, matching the split above
+    min_alarm = train_df["alarm"].value_counts().min()
+    train_df = pd.concat(
+        [
+            train_df[train_df["alarm"] == 0].sample(min_alarm, random_state=123456789),
+            train_df[train_df["alarm"] == 1].sample(min_alarm, random_state=123456789),
+        ]
+    )
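+    # (an alternative to undersampling would be class weighting, e.g.
+    # nn.BCEWithLogitsLoss(pos_weight=...), but that would change the loss setup)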
+    print(f"Train alarm distribution: {train_df['alarm'].value_counts()}")
 
-def train_model(X: np.ndarray, y: np.ndarray):
+    X_train = train_df.drop(columns=["alarm"]).values
+    y_train = train_df["alarm"].values
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    X_test = test_df.drop(columns=["alarm"]).values
+    y_test = test_df["alarm"].values
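+
+    # sanity check mirroring the original 20-feature assertion
+    assert X_train.shape[1] == INPUT_FEATURES, "feature count should match model input"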
 
-    # Splitting the data into training and testing sets
-    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
+    return X_train, y_train, X_test, y_test
+
+
+def train_model(
+    X_train: np.ndarray,
+    y_train: np.ndarray,
+    X_test: np.ndarray,
+    y_test: np.ndarray,
+):
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     # Setting up the data loader
     train_loader = torch.utils.data.DataLoader(
@@ -191,22 +242,45 @@ def train_model(X: np.ndarray, y: np.ndarray):
     # calculate confusion matrix
     cm = confusion_matrix(y_test, y_test_pred_binary)
 
+    print(f"Confusion matrix:\n{cm}")
+
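+    # sklearn's confusion_matrix convention: rows are true labels, columns are
+    # predicted labels, so cm[0, 0] = TN, cm[0, 1] = FP, cm[1, 0] = FN, cm[1, 1] = TP
+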
+    # calculate accuracy
+    accuracy = np.sum(np.diag(cm)) / np.sum(cm)
+    print(f"Accuracy: {accuracy}")
+
+    # recall = TP / (TP + FN)
+    recall = cm[1, 1] / (cm[1, 0] + cm[1, 1])
+    print(f"Recall: {recall}")
+
+    # precision = TP / (TP + FP)
+    precision = cm[1, 1] / (cm[0, 1] + cm[1, 1])
+    print(f"Precision: {precision}")
+
+    # F1: harmonic mean of precision and recall
+    f1 = 2 * (precision * recall) / (precision + recall)
+    print(f"F1 score: {f1}")
+
+    # F2: F-beta with beta = 2, weighting recall twice as heavily as precision
+    f2 = 5 * (precision * recall) / (4 * precision + recall)
+    print(f"F2 score: {f2}")
+
+    # AUC is computed from the continuous scores, not the thresholded predictions
+    auc = roc_auc_score(y_test, y_test_pred)
+    print(f"AUC: {auc}")
+
     # plot confusion matrix using seaborn
     sns.heatmap(cm, annot=True, fmt="d")
     plt.xlabel("Predicted")
     plt.ylabel("True")
     plt.show()
 
-    # calculate accuracy
-    accuracy = np.sum(np.diag(cm)) / np.sum(cm)
-    print(f"Accuracy: {accuracy}")
-
     return model
 
 
 def main():
-    X, y = prepare_data()
-    model = train_model(X, y)
+    X_train, y_train, X_test, y_test = prepare_data()
+    model = train_model(X_train, y_train, X_test, y_test)
 
     # print parameter count of model
     print(f"Parameter count: {sum(p.numel() for p in model.parameters())}")