diff --git a/analysis.py b/analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..da3c9b10e478dbd20e7317e1964fa69275d94abb
--- /dev/null
+++ b/analysis.py
@@ -0,0 +1,21 @@
+import pandas as pd
+from sklearn.preprocessing import StandardScaler
+from scipy.stats import gaussian_kde
+
+def get_joint_pdf(df_train, df_test):
+    # Plot the joint PDF of all 10 features
+    scaler = StandardScaler()
+    df_train_normalized = pd.DataFrame(scaler.fit_transform(df_train), columns=df_train.columns)
+    df_test_normalized = pd.DataFrame(scaler.fit_transform(df_test), columns=df_test.columns)
+
+    #sns.set_theme(style="white", palette="muted", color_codes=True)
+    #sns.kdeplot(data=df_normalized, shade=True)
+    #sns.jointplot(data=df_normalized, kind="kde", height=7, space=0)
+
+    kde = gaussian_kde(df_train_normalized.T)
+
+    # Evaluate the KDE 
+    joint_pdf_train = kde(df_train_normalized.T)
+    joint_pdf_test = kde(df_test_normalized.T)
+
+    return joint_pdf_train, joint_pdf_test
\ No newline at end of file
diff --git a/logs/fix-data-param-romaric.csv b/logs/fix-data-param-romaric.csv
new file mode 100644
index 0000000000000000000000000000000000000000..0886519d1c7e674d1a8fab72b3405e83cb8e8497
--- /dev/null
+++ b/logs/fix-data-param-romaric.csv
@@ -0,0 +1,21 @@
+DataSeed,ModelSeed,Accuracy,Recall,Precision,F1 Score,F2 Score,AUC
+123456789, 821, 0.7738095238095238, 0.9565217391304348, 0.7213114754098361, 0.822429906542056, 0.8979591836734693, 0.9410755148741418
+123456789, 6712, 0.7619047619047619, 0.9347826086956522, 0.7166666666666667, 0.8113207547169811, 0.8811475409836066, 0.9445080091533181
+123456789, 8255, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.931350114416476
+123456789, 4502, 0.7619047619047619, 0.9565217391304348, 0.7096774193548387, 0.8148148148148149, 0.894308943089431, 0.908466819221968
+123456789, 3403, 0.7857142857142857, 0.9347826086956522, 0.7413793103448276, 0.826923076923077, 0.8884297520661157, 0.9199084668192219
+123456789, 3008, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9296338672768879
+123456789, 123, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.9336384439359268
+123456789, 6564, 0.7738095238095238, 0.9565217391304348, 0.7213114754098361, 0.822429906542056, 0.8979591836734693, 0.9284897025171625
+123456789, 4610, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9290617848970252
+123456789, 3914, 0.7738095238095238, 0.9565217391304348, 0.7213114754098361, 0.822429906542056, 0.8979591836734693, 0.9364988558352403
+123456789, 2658, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9256292906178489
+123456789, 8773, 0.7738095238095238, 0.9347826086956522, 0.7288135593220338, 0.819047619047619, 0.8847736625514403, 0.937070938215103
+123456789, 4258, 0.7619047619047619, 0.9130434782608695, 0.7241379310344828, 0.8076923076923076, 0.8677685950413221, 0.9221967963386727
+123456789, 1430, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.933066361556064
+123456789, 2979, 0.7857142857142857, 0.9347826086956522, 0.7413793103448276, 0.826923076923077, 0.8884297520661157, 0.9307780320366132
+123456789, 3935, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9233409610983982
+123456789, 6961, 0.7619047619047619, 0.9347826086956522, 0.7166666666666667, 0.8113207547169811, 0.8811475409836066, 0.9290617848970252
+123456789, 2381, 0.7738095238095238, 0.9130434782608695, 0.7368421052631579, 0.8155339805825242, 0.871369294605809, 0.9296338672768879
+123456789, 4368, 0.7619047619047619, 0.9347826086956522, 0.7166666666666667, 0.8113207547169811, 0.8811475409836066, 0.9439359267734554
+123456789, 3816, 0.7857142857142857, 0.9347826086956522, 0.7413793103448276, 0.826923076923077, 0.8884297520661157, 0.9336384439359268
diff --git a/logs/fix-data.csv b/logs/fix-data.csv
new file mode 100644
index 0000000000000000000000000000000000000000..59954d86dd70017970b4a8ca40d34b1ed1507977
--- /dev/null
+++ b/logs/fix-data.csv
@@ -0,0 +1,27 @@
+DataSeed,ModelSeed,Accuracy,Recall,Precision,F1 Score,F2 Score,AUC
+123456789, 4224, 0.7738095238095238, 0.9347826086956522, 0.7288135593220338, 0.819047619047619, 0.8847736625514403, 0.9290617848970252
+123456789, 7729, 0.7857142857142857, 0.9782608695652174, 0.7258064516129032, 0.8333333333333333, 0.9146341463414634, 0.9170480549199085
+123456789, 197, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9216247139588101
+123456789, 8678, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.927345537757437
+123456789, 1418, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9193363844393593
+123456789, 1742, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.9170480549199085
+123456789, 2456, 0.7857142857142857, 0.8913043478260869, 0.7592592592592593, 0.8200000000000001, 0.8613445378151261, 0.9302059496567505
+123456789, 5305, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.8844393592677346
+123456789, 6591, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.9262013729977117
+123456789, 4043, 0.75, 0.9130434782608695, 0.711864406779661, 0.7999999999999999, 0.8641975308641975, 0.9090389016018308
+123456789, 6737, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.9302059496567506
+123456789, 3856, 0.7976190476190477, 0.8913043478260869, 0.7735849056603774, 0.8282828282828283, 0.8649789029535865, 0.925629290617849
+123456789, 4955, 0.7380952380952381, 0.9130434782608695, 0.7, 0.7924528301886793, 0.8606557377049181, 0.9193363844393593
+123456789, 3355, 0.7738095238095238, 0.9347826086956522, 0.7288135593220338, 0.819047619047619, 0.8847736625514403, 0.9124713958810069
+123456789, 9015, 0.7142857142857143, 0.9565217391304348, 0.6666666666666666, 0.7857142857142856, 0.88, 0.9204805491990846
+123456789, 4165, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.9347826086956522
+123456789, 4498, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.9244851258581236
+123456789, 7503, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9279176201372998
+123456789, 7120, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.8781464530892449
+123456789, 5406, 0.7976190476190477, 0.9130434782608695, 0.7636363636363637, 0.8316831683168316, 0.8786610878661087, 0.9290617848970252
+123456789, 1173, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.9147597254004577
+123456789, 810, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.9290617848970252
+123456789, 9865, 0.7857142857142857, 0.9347826086956522, 0.7413793103448276, 0.826923076923077, 0.8884297520661157, 0.8970251716247141
+123456789, 5329, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.931350114416476
+
+
diff --git a/logs/fix-model-param-romaric.csv b/logs/fix-model-param-romaric.csv
new file mode 100644
index 0000000000000000000000000000000000000000..ca46e4d399ceddd3b6525f5b47dbe9deeb32f6b4
--- /dev/null
+++ b/logs/fix-model-param-romaric.csv
@@ -0,0 +1,21 @@
+DataSeed,ModelSeed,Accuracy,Recall,Precision,F1 Score,F2 Score,AUC
+5517, 123456789, 0.7261904761904762, 0.9, 0.7142857142857143, 0.7964601769911505, 0.8555133079847909, 0.8858823529411763
+29, 123456789, 0.7261904761904762, 0.8421052631578947, 0.6530612244897959, 0.735632183908046, 0.7960199004975124, 0.8838672768878718
+1651, 123456789, 0.7857142857142857, 0.92, 0.7666666666666667, 0.8363636363636363, 0.8846153846153846, 0.8758823529411764
+7768, 123456789, 0.7619047619047619, 0.9111111111111111, 0.7192982456140351, 0.803921568627451, 0.8649789029535865, 0.8997150997150998
+6584, 123456789, 0.7261904761904762, 0.8703703703703703, 0.746031746031746, 0.8034188034188035, 0.8422939068100358, 0.8728395061728396
+4261, 123456789, 0.8095238095238095, 0.9019607843137255, 0.8070175438596491, 0.8518518518518519, 0.8812260536398467, 0.9067142008318478
+9265, 123456789, 0.7023809523809523, 0.94, 0.6811594202898551, 0.7899159663865546, 0.8736059479553903, 0.8158823529411765
+9010, 123456789, 0.7261904761904762, 0.9777777777777777, 0.6666666666666666, 0.7927927927927928, 0.8943089430894309, 0.9287749287749287
+1586, 123456789, 0.7619047619047619, 0.9183673469387755, 0.7377049180327869, 0.8181818181818182, 0.8754863813229573, 0.8641399416909622
+7263, 123456789, 0.7261904761904762, 0.8444444444444444, 0.7037037037037037, 0.7676767676767676, 0.811965811965812, 0.8655270655270656
+4156, 123456789, 0.7380952380952381, 0.9375, 0.703125, 0.8035714285714286, 0.87890625, 0.9184027777777777
+7070, 123456789, 0.7142857142857143, 0.9347826086956522, 0.671875, 0.7818181818181819, 0.8669354838709679, 0.8770022883295194
+8650, 123456789, 0.7142857142857143, 0.803921568627451, 0.7454545454545455, 0.7735849056603775, 0.7915057915057916, 0.863339275103981
+5088, 123456789, 0.5952380952380952, 1.0, 0.569620253164557, 0.7258064516129034, 0.8687258687258687, 0.8119658119658121
+9173, 123456789, 0.7023809523809523, 0.84, 0.711864406779661, 0.7706422018348624, 0.8108108108108109, 0.858235294117647
+4985, 123456789, 0.7619047619047619, 0.9565217391304348, 0.7096774193548387, 0.8148148148148149, 0.894308943089431, 0.88558352402746
+99, 123456789, 0.7380952380952381, 0.9069767441860465, 0.6842105263157895, 0.78, 0.851528384279476, 0.8774815655133296
+9309, 123456789, 0.7142857142857143, 0.9318181818181818, 0.6612903225806451, 0.7735849056603773, 0.861344537815126, 0.86875
+9098, 123456789, 0.7738095238095238, 0.8461538461538461, 0.8, 0.8224299065420562, 0.8365019011406846, 0.9038461538461539
+3474, 123456789, 0.6309523809523809, 0.82, 0.6507936507936508, 0.7256637168141592, 0.779467680608365, 0.8188235294117647
diff --git a/logs/fix-model.csv b/logs/fix-model.csv
new file mode 100644
index 0000000000000000000000000000000000000000..c6aacb69b0e7d5847e209f2c84f57f9a97c633bd
--- /dev/null
+++ b/logs/fix-model.csv
@@ -0,0 +1,24 @@
+DataSeed,ModelSeed,Accuracy,Recall,Precision,F1 Score,F2 Score,AUC
+2659, 123456789, 0.7619047619047619, 0.86, 0.7678571428571429, 0.8113207547169812, 0.83984375, 0.8870588235294117
+5355, 123456789, 0.7738095238095238, 0.9111111111111111, 0.7321428571428571, 0.8118811881188118, 0.8686440677966102, 0.9059829059829061
+6402, 123456789, 0.8095238095238095, 0.9230769230769231, 0.8, 0.8571428571428571, 0.8955223880597014, 0.907451923076923
+8817, 123456789, 0.7380952380952381, 1.0, 0.6271186440677966, 0.7708333333333333, 0.8937198067632851, 0.9355951696377227
+7097, 123456789, 0.7619047619047619, 0.9111111111111111, 0.7192982456140351, 0.803921568627451, 0.8649789029535865, 0.8831908831908832
+3543, 123456789, 0.7976190476190477, 0.9056603773584906, 0.8, 0.8495575221238938, 0.8823529411764706, 0.8746195982958004
+3630, 123456789, 0.7261904761904762, 0.8, 0.7547169811320755, 0.7766990291262137, 0.790513833992095, 0.8400000000000001
+481, 123456789, 0.7619047619047619, 0.8297872340425532, 0.7647058823529411, 0.7959183673469387, 0.8158995815899581, 0.8964922369177688
+8874, 123456789, 0.8214285714285714, 0.9555555555555556, 0.7678571428571429, 0.8514851485148515, 0.9110169491525425, 0.8968660968660969
+2423, 123456789, 0.75, 0.9, 0.7377049180327869, 0.8108108108108109, 0.8620689655172414, 0.8847058823529411
+572, 123456789, 0.7380952380952381, 0.8775510204081632, 0.7288135593220338, 0.7962962962962963, 0.8431372549019608, 0.8256559766763849
+9467, 123456789, 0.7857142857142857, 0.74, 0.8809523809523809, 0.8043478260869565, 0.7644628099173554, 0.91
+4669, 123456789, 0.75, 0.8703703703703703, 0.7704918032786885, 0.817391304347826, 0.848375451263538, 0.880246913580247
+3840, 123456789, 0.7380952380952381, 0.9347826086956522, 0.6935483870967742, 0.7962962962962964, 0.8739837398373984, 0.8827231121281465
+717, 123456789, 0.7380952380952381, 0.8636363636363636, 0.7037037037037037, 0.7755102040816326, 0.8260869565217391, 0.8522727272727272
+4353, 123456789, 0.7738095238095238, 0.9782608695652174, 0.7142857142857143, 0.8256880733944955, 0.9109311740890689, 0.9399313501144165
+7024, 123456789, 0.7619047619047619, 0.8043478260869565, 0.7708333333333334, 0.7872340425531915, 0.7974137931034484, 0.8775743707093822
+5849, 123456789, 0.7857142857142857, 0.96, 0.75, 0.8421052631578947, 0.9090909090909091, 0.9252941176470586
+7847, 123456789, 0.7380952380952381, 0.8095238095238095, 0.7083333333333334, 0.7555555555555556, 0.7870370370370371, 0.8781179138321996
+6263, 123456789, 0.8095238095238095, 0.9137931034482759, 0.828125, 0.8688524590163935, 0.8952702702702704, 0.8925729442970822
+107, 123456789, 0.7380952380952381, 0.9166666666666666, 0.7096774193548387, 0.7999999999999999, 0.8661417322834645, 0.8790509259259259
+
+
diff --git a/logs/random.csv b/logs/random.csv
new file mode 100644
index 0000000000000000000000000000000000000000..6ee86c2ee18f17888c112a4f352f8dcee10cd114
--- /dev/null
+++ b/logs/random.csv
@@ -0,0 +1,12 @@
+Run,Accuracy,Recall,Precision,F1 Score,F2 Score,AUC
+8277, 0.6785714285714286, 0.8181818181818182, 0.6545454545454545, 0.7272727272727274, 0.7792207792207791, 0.8636363636363636
+9552, 0.6309523809523809, 0.9142857142857143, 0.5333333333333333, 0.6736842105263158, 0.8000000000000002, 0.8478134110787172
+8712, 0.6666666666666666, 0.8181818181818182, 0.6428571428571429, 0.7200000000000001, 0.7758620689655173, 0.8045454545454546
+1418, 0.7142857142857143, 0.8235294117647058, 0.7368421052631579, 0.7777777777777778, 0.8045977011494252, 0.8354129530600121
+3720, 0.7380952380952381, 0.82, 0.7592592592592593, 0.7884615384615384, 0.8070866141732284, 0.7582352941176471
+9510, 0.8214285714285714, 0.8936170212765957, 0.8076923076923077, 0.8484848484848485, 0.875, 0.8861414606095457
+2, 0.6547619047619048, 0.9166666666666666, 0.6376811594202898, 0.752136752136752, 0.842911877394636, 0.7777777777777777
+209, 0.7619047619047619, 0.9069767441860465, 0.7090909090909091, 0.7959183673469388, 0.8590308370044054, 0.8740782756664776
+1000, 0.7023809523809523, 0.8222222222222222, 0.6851851851851852, 0.7474747474747475, 0.7905982905982906, 0.8051282051282052
+9103, 0.75, 0.8936170212765957, 0.7241379310344828, 0.7999999999999999, 0.8536585365853658, 0.8959171937895342
+7355, 0.75, 0.868421052631579, 0.673469387755102, 0.7586206896551724, 0.8208955223880596, 0.8975972540045767
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..36ab1dad2376ee1df652d255b726b438ce83511a
--- /dev/null
+++ b/main.py
@@ -0,0 +1,31 @@
+import mlp_train as m
+import plot_results as p
+import analysis as a
+
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+def main():
+    for data_seed in np.random.randint(0, 10000, size=10):
+        model_seed = m.RANDOM_SEED
+        train_df, test_df = m.prepare_test_train_df(data_seed)
+        X_train, y_train, X_test, y_test = m.prepare_features_and_labels(train_df, test_df)
+        joint_pdf_train, joint_pdf_test = a.get_joint_pdf(pd.DataFrame(X_train), pd.DataFrame(X_test))
+        model, train_losses, test_losses, train_accuracies, test_accuracies = m.train_model(X_train, y_train, X_test, y_test, model_seed)
+        y_test_pred, y_test_pred_binary = m.get_evaluation(model, X_test)
+        results = m.sort_testdata_into_cm(test_df, y_test_pred, y_test_pred_binary)
+
+        fig, ax = plt.subplots(2, 2, figsize=(24, 6))
+        p.plot_joint_pdf(joint_pdf_train, joint_pdf_test, ax[0, 0])
+        p.plot_train_test_evolution(train_accuracies, test_accuracies, ax[0, 1], metric = 'accuracy')
+        p.plot_train_test_evolution(train_losses, test_losses, ax[1, 0])
+        
+        sns.kdeplot(data = results, x = 'predicted_raw', hue = "CM",  ax=ax[1, 1])
+        
+        #plt.savefig(f"./results/results_data{data_seed}_model{model_seed}.png")
+        plt.show()
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/mlp_from_paper.ipynb b/measurements_exploration.ipynb
similarity index 54%
rename from mlp_from_paper.ipynb
rename to measurements_exploration.ipynb
index acf58f3906033ca087be9e34263274ecdfe9ddf3..a78fb1399bb265c35b2d1b945e8d5f4c1737dfa1 100644
--- a/mlp_from_paper.ipynb
+++ b/measurements_exploration.ipynb
@@ -563,887 +563,6 @@
     "total_temp = total_num_profiles - total_sal\n",
     "print(f'Total #Temperature Measurement : {total_temp}')"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [
-    {
-     "ename": "ValueError",
-     "evalue": "Length mismatch: Expected axis has 0 elements, new values have 28 elements",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
-      "Cell \u001b[0;32mIn[3], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m df \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mDataFrame()\n\u001b[0;32m----> 2\u001b[0m \u001b[43mdf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcolumns\u001b[49m \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mID\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPARAM\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mLON\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mLAT\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPOS_QC\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mBAT\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mJULD\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mJULD_QC\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mFALSEorTRUE\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPRES\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPRES_QC\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPSAL\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPSAL_QC\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPSAL_MED\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPSAL_MIN\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPSAL_MAX\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTEMP\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTEMP_QC\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTEMP_MED\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTEMP_MIN\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTEMP_MAX\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPRESs\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPSAL_MEDs\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPSAL_MINs\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPSAL_MAXs\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTEMP_MEDs\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTEMP_MINs\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTEMP_MAXs\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m      3\u001b[0m data \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m      4\u001b[0m labels \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mID\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPARAM\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mLON\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mLAT\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPOS_QC\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mBAT\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mJULD\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mJULD_QC\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mFALSEorTRUE\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPRES\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPRES_QC\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPSAL\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPSAL_QC\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPSAL_MED\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPSAL_MIN\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPSAL_MAX\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTEMP\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTEMP_QC\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTEMP_MED\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTEMP_MIN\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTEMP_MAX\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPRESs\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPSAL_MEDs\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPSAL_MINs\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPSAL_MAXs\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTEMP_MEDs\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTEMP_MINs\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTEMP_MAXs\u001b[39m\u001b[38;5;124m'\u001b[39m]\n",
-      "File \u001b[0;32m~/.venv/lib/python3.10/site-packages/pandas/core/generic.py:6310\u001b[0m, in \u001b[0;36mNDFrame.__setattr__\u001b[0;34m(self, name, value)\u001b[0m\n\u001b[1;32m   6308\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   6309\u001b[0m     \u001b[38;5;28mobject\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__getattribute__\u001b[39m(\u001b[38;5;28mself\u001b[39m, name)\n\u001b[0;32m-> 6310\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mobject\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__setattr__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   6311\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m:\n\u001b[1;32m   6312\u001b[0m     \u001b[38;5;28;01mpass\u001b[39;00m\n",
-      "File \u001b[0;32mproperties.pyx:69\u001b[0m, in \u001b[0;36mpandas._libs.properties.AxisProperty.__set__\u001b[0;34m()\u001b[0m\n",
-      "File \u001b[0;32m~/.venv/lib/python3.10/site-packages/pandas/core/generic.py:813\u001b[0m, in \u001b[0;36mNDFrame._set_axis\u001b[0;34m(self, axis, labels)\u001b[0m\n\u001b[1;32m    808\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    809\u001b[0m \u001b[38;5;124;03mThis is called from the cython code when we set the `index` attribute\u001b[39;00m\n\u001b[1;32m    810\u001b[0m \u001b[38;5;124;03mdirectly, e.g. `series.index = [1, 2, 3]`.\u001b[39;00m\n\u001b[1;32m    811\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    812\u001b[0m labels \u001b[38;5;241m=\u001b[39m ensure_index(labels)\n\u001b[0;32m--> 813\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_mgr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mset_axis\u001b[49m\u001b[43m(\u001b[49m\u001b[43maxis\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    814\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_clear_item_cache()\n",
-      "File \u001b[0;32m~/.venv/lib/python3.10/site-packages/pandas/core/internals/managers.py:238\u001b[0m, in \u001b[0;36mBaseBlockManager.set_axis\u001b[0;34m(self, axis, new_labels)\u001b[0m\n\u001b[1;32m    236\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mset_axis\u001b[39m(\u001b[38;5;28mself\u001b[39m, axis: AxisInt, new_labels: Index) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    237\u001b[0m     \u001b[38;5;66;03m# Caller is responsible for ensuring we have an Index object.\u001b[39;00m\n\u001b[0;32m--> 238\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_set_axis\u001b[49m\u001b[43m(\u001b[49m\u001b[43maxis\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnew_labels\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    239\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maxes[axis] \u001b[38;5;241m=\u001b[39m new_labels\n",
-      "File \u001b[0;32m~/.venv/lib/python3.10/site-packages/pandas/core/internals/base.py:98\u001b[0m, in \u001b[0;36mDataManager._validate_set_axis\u001b[0;34m(self, axis, new_labels)\u001b[0m\n\u001b[1;32m     95\u001b[0m     \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[1;32m     97\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m new_len \u001b[38;5;241m!=\u001b[39m old_len:\n\u001b[0;32m---> 98\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m     99\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLength mismatch: Expected axis has \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mold_len\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m elements, new \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    100\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalues have \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnew_len\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m elements\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    101\u001b[0m     )\n",
-      "\u001b[0;31mValueError\u001b[0m: Length mismatch: Expected axis has 0 elements, new values have 28 elements"
-     ]
-    }
-   ],
-   "source": [
-    "df = pd.DataFrame()\n",
-    "df.columns = ['ID', 'PARAM', 'LON', 'LAT', 'POS_QC', 'BAT', 'JULD', 'JULD_QC', 'FALSEorTRUE', 'PRES', 'PRES_QC', 'PSAL', 'PSAL_QC', 'PSAL_MED', 'PSAL_MIN', 'PSAL_MAX', 'TEMP', 'TEMP_QC', 'TEMP_MED', 'TEMP_MIN', 'TEMP_MAX', 'PRESs', 'PSAL_MEDs', 'PSAL_MINs', 'PSAL_MAXs', 'TEMP_MEDs', 'TEMP_MINs', 'TEMP_MAXs']\n",
-    "data = []\n",
-    "labels = ['ID', 'PARAM', 'LON', 'LAT', 'POS_QC', 'BAT', 'JULD', 'JULD_QC', 'FALSEorTRUE', 'PRES', 'PRES_QC', 'PSAL', 'PSAL_QC', 'PSAL_MED', 'PSAL_MIN', 'PSAL_MAX', 'TEMP', 'TEMP_QC', 'TEMP_MED', 'TEMP_MIN', 'TEMP_MAX', 'PRESs', 'PSAL_MEDs', 'PSAL_MINs', 'PSAL_MAXs', 'TEMP_MEDs', 'TEMP_MINs', 'TEMP_MAXs']\n",
-    "for filename in os.listdir(\"./dataset\"):\n",
-    "    file2read = netcdf.NetCDFFile(\"./dataset/\"+filename, 'r')\n",
-    "    print(\"\\n\".join(str(file2read.variables).split(',')))\n",
-    "    labels = file2read.variables.keys()\n",
-    "    break\n",
-    "    #data.append(temp[:]*1)\n",
-    "    #file2read.close()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 23,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "ID(1,)\n",
-      "[5902393.]\n",
-      "PARAM(1,)\n",
-      "[b'S']\n",
-      "LON(1,)\n",
-      "[-176.45863342]\n",
-      "LAT(1,)\n",
-      "[-29.23719025]\n",
-      "POS_QC(1,)\n",
-      "[1.]\n",
-      "BAT(1,)\n",
-      "[1.]\n",
-      "JULD(1,)\n",
-      "[737288.87855324]\n",
-      "JULD_QC(1,)\n",
-      "[1.]\n",
-      "FALSEorTRUE(1,)\n",
-      "[1.]\n",
-      "PRES(1, 1008)\n",
-      "[[1.08000004e+00 2.03999996e+00 3.03999996e+00 ... 2.00192004e+03\n",
-      "  2.00400000e+03 2.00607996e+03]]\n",
-      "PRES_QC(1, 1008)\n",
-      "[[1. 1. 1. ... 1. 1. 1.]]\n",
-      "PSAL(1, 1008)\n",
-      "[[35.69332123 35.69332123 35.69332123 ... 34.62292099 34.62311935\n",
-      "  34.62332153]]\n",
-      "PSAL_QC(1, 1008)\n",
-      "[[1. 1. 1. ... 1. 1. 1.]]\n",
-      "PSAL_MED(1, 1008)\n",
-      "[[35.612     35.612     35.612     ... 34.6185015 34.6185015 34.6185015]]\n",
-      "PSAL_MIN(1, 1008)\n",
-      "[[35.41     35.41     35.41     ... 34.608002 34.608002 34.608002]]\n",
-      "PSAL_MAX(1, 1008)\n",
-      "[[35.736 35.736 35.736 ... 34.632 34.632 34.632]]\n",
-      "TEMP(1, 1008)\n",
-      "[[19.04599953 19.04500008 19.04500008 ...  2.36599994  2.36299992\n",
-      "   2.36299992]]\n",
-      "TEMP_QC(1, 1008)\n",
-      "[[1. 1. 1. ... 1. 1. 1.]]\n",
-      "TEMP_MED(1, 1008)\n",
-      "[[21.983 21.983 21.983 ...  2.387  2.387  2.387]]\n",
-      "TEMP_MIN(1, 1008)\n",
-      "[[18.431999 18.431999 18.431999 ...  2.303     2.303     2.303   ]]\n",
-      "TEMP_MAX(1, 1008)\n",
-      "[[25.636999 25.636999 25.636999 ...  2.458     2.458     2.458   ]]\n",
-      "PRESs(200,)\n",
-      "[  -5.   20.   20.   40.   40.   60.   60.   80.   80.  100.  100.  120.\n",
-      "  120.  140.  140.  160.  160.  180.  180.  200.  200.  220.  220.  240.\n",
-      "  240.  260.  260.  280.  280.  300.  300.  320.  320.  340.  340.  360.\n",
-      "  360.  380.  380.  400.  400.  420.  420.  440.  440.  460.  460.  480.\n",
-      "  480.  500.  500.  520.  520.  540.  540.  560.  560.  580.  580.  600.\n",
-      "  600.  620.  620.  640.  640.  660.  660.  680.  680.  700.  700.  720.\n",
-      "  720.  740.  740.  760.  760.  780.  780.  800.  800.  820.  820.  840.\n",
-      "  840.  860.  860.  880.  880.  900.  900.  920.  920.  940.  940.  960.\n",
-      "  960.  980.  980. 1000. 1000. 1020. 1020. 1040. 1040. 1060. 1060. 1080.\n",
-      " 1080. 1100. 1100. 1120. 1120. 1140. 1140. 1160. 1160. 1180. 1180. 1200.\n",
-      " 1200. 1220. 1220. 1240. 1240. 1260. 1260. 1280. 1280. 1300. 1300. 1320.\n",
-      " 1320. 1340. 1340. 1360. 1360. 1380. 1380. 1400. 1400. 1420. 1420. 1440.\n",
-      " 1440. 1460. 1460. 1480. 1480. 1500. 1500. 1520. 1520. 1540. 1540. 1560.\n",
-      " 1560. 1580. 1580. 1600. 1600. 1620. 1620. 1640. 1640. 1660. 1660. 1680.\n",
-      " 1680. 1700. 1700. 1720. 1720. 1740. 1740. 1760. 1760. 1780. 1780. 1800.\n",
-      " 1800. 1820. 1820. 1840. 1840. 1860. 1860. 1880. 1880. 1900. 1900. 1920.\n",
-      " 1920. 1940. 1940. 1960. 1960. 1980. 1980. 2050.]\n",
-      "PSAL_MEDs(1, 200)\n",
-      "[[35.612     35.612     35.6461885 35.6461885 35.6774995 35.6774995\n",
-      "  35.674999  35.674999  35.66      35.66      35.640501  35.640501\n",
-      "  35.6064985 35.6064985 35.576     35.576     35.5555    35.5555\n",
-      "  35.525002  35.525002  35.487     35.487     35.461536  35.461536\n",
-      "  35.421692  35.421692  35.3605005 35.3605005 35.317001  35.317001\n",
-      "  35.263216  35.263216  35.2435    35.2435    35.1461505 35.1461505\n",
-      "  35.079507  35.079507  35.007     35.007     34.972386  34.972386\n",
-      "  34.904515  34.904515  34.841     34.841     34.784414  34.784414\n",
-      "  34.706001  34.706001  34.679986  34.679986  34.635945  34.635945\n",
-      "  34.588619  34.588619  34.549747  34.549747  34.502731  34.502731\n",
-      "  34.487406  34.487406  34.466534  34.466534  34.433998  34.433998\n",
-      "  34.421313  34.421313  34.402378  34.402378  34.39606   34.39606\n",
-      "  34.38526   34.38526   34.375     34.375     34.3653305 34.3653305\n",
-      "  34.356304  34.356304  34.3543625 34.3543625 34.3521    34.3521\n",
-      "  34.350195  34.350195  34.34843   34.34843   34.3466725 34.3466725\n",
-      "  34.3461165 34.3461165 34.344837  34.344837  34.348     34.348\n",
-      "  34.347054  34.347054  34.3480735 34.3480735 34.3474315 34.3474315\n",
-      "  34.3550305 34.3550305 34.357611  34.357611  34.3624655 34.3624655\n",
-      "  34.370628  34.370628  34.372883  34.372883  34.377059  34.377059\n",
-      "  34.3822305 34.3822305 34.3906215 34.3906215 34.402714  34.402714\n",
-      "  34.4075905 34.4075905 34.4167915 34.4167915 34.4256685 34.4256685\n",
-      "  34.434753  34.434753  34.4476965 34.4476965 34.453896  34.453896\n",
-      "  34.4657205 34.4657205 34.4771405 34.4771405 34.48754   34.48754\n",
-      "  34.502121  34.502121  34.5070415 34.5070415 34.5160115 34.5160115\n",
-      "  34.523035  34.523035  34.529335  34.529335  34.537029  34.537029\n",
-      "  34.543737  34.543737  34.548712  34.548712  34.553499  34.553499\n",
-      "  34.560501  34.560501  34.5650405 34.5650405 34.572206  34.572206\n",
-      "  34.5781695 34.5781695 34.583     34.583     34.5846395 34.5846395\n",
-      "  34.5865885 34.5865885 34.5877905 34.5877905 34.5900925 34.5900925\n",
-      "  34.591999  34.591999  34.5953285 34.5953285 34.5993835 34.5993835\n",
-      "  34.6009    34.6009    34.603407  34.603407  34.6055015 34.6055015\n",
-      "  34.6075795 34.6075795 34.610514  34.610514  34.611377  34.611377\n",
-      "  34.6129805 34.6129805 34.6145    34.6145    34.617977  34.617977\n",
-      "  34.6185015 34.6185015]]\n",
-      "PSAL_MINs(1, 200)\n",
-      "[[35.41     35.41     35.41     35.41     35.419998 35.419998 35.508507\n",
-      "  35.508507 35.553299 35.553299 35.584999 35.584999 35.551998 35.551998\n",
-      "  35.488998 35.488998 35.417999 35.417999 35.261002 35.261002 35.192001\n",
-      "  35.192001 35.176998 35.176998 35.101002 35.101002 35.042999 35.042999\n",
-      "  35.002998 35.002998 34.978001 34.978001 34.923    34.923    34.827\n",
-      "  34.827    34.796001 34.796001 34.721001 34.721001 34.691449 34.691449\n",
-      "  34.635691 34.635691 34.580002 34.580002 34.543568 34.543568 34.488998\n",
-      "  34.488998 34.478025 34.478025 34.454215 34.454215 34.431    34.431\n",
-      "  34.416534 34.416534 34.396    34.396    34.39206  34.39206  34.382191\n",
-      "  34.382191 34.368999 34.368999 34.365269 34.365269 34.360001 34.360001\n",
-      "  34.356108 34.356108 34.348964 34.348964 34.341999 34.341999 34.340765\n",
-      "  34.340765 34.338001 34.338001 34.336997 34.336997 34.333028 34.333028\n",
-      "  34.328999 34.328999 34.330973 34.330973 34.330002 34.330002 34.330218\n",
-      "  34.330218 34.330613 34.330613 34.331001 34.331001 34.33285  34.33285\n",
-      "  34.334    34.334    34.337686 34.337686 34.338885 34.338885 34.34\n",
-      "  34.34     34.343852 34.343852 34.348999 34.348999 34.351039 34.351039\n",
-      "  34.354619 34.354619 34.358002 34.358002 34.367286 34.367286 34.380001\n",
-      "  34.380001 34.384924 34.384924 34.393795 34.393795 34.402    34.402\n",
-      "  34.411498 34.411498 34.424    34.424    34.43028  34.43028  34.441016\n",
-      "  34.441016 34.451    34.451    34.463032 34.463032 34.479    34.479\n",
-      "  34.482697 34.482697 34.489071 34.489071 34.494999 34.494999 34.503525\n",
-      "  34.503525 34.514999 34.514999 34.517459 34.517459 34.521851 34.521851\n",
-      "  34.526001 34.526001 34.532753 34.532753 34.542    34.542    34.545336\n",
-      "  34.545336 34.551348 34.551348 34.556999 34.556999 34.562556 34.562556\n",
-      "  34.57     34.57     34.572297 34.572297 34.576259 34.576259 34.580002\n",
-      "  34.580002 34.583412 34.583412 34.588001 34.588001 34.589158 34.589158\n",
-      "  34.591153 34.591153 34.592999 34.592999 34.595154 34.595154 34.598\n",
-      "  34.598    34.599389 34.599389 34.601784 34.601784 34.604    34.604\n",
-      "  34.605724 34.605724 34.608002 34.608002]]\n",
-      "PSAL_MAXs(1, 200)\n",
-      "[[35.736    35.736    35.745998 35.745998 35.744999 35.744999 35.715\n",
-      "  35.715    35.699001 35.699001 35.678001 35.678001 35.667999 35.667999\n",
-      "  35.653999 35.653999 35.636002 35.636002 35.627998 35.627998 35.587002\n",
-      "  35.587002 35.577    35.577    35.592999 35.592999 35.525002 35.525002\n",
-      "  35.474998 35.474998 35.425999 35.425999 35.382    35.382    35.333\n",
-      "  35.333    35.220001 35.220001 35.160999 35.160999 35.121681 35.121681\n",
-      "  35.081451 35.081451 35.067001 35.067001 34.989928 34.989928 34.883999\n",
-      "  34.883999 34.837638 34.837638 34.753344 34.753344 34.703999 34.703999\n",
-      "  34.655856 34.655856 34.611    34.611    34.596289 34.596289 34.569782\n",
-      "  34.569782 34.544998 34.544998 34.512099 34.512099 34.466999 34.466999\n",
-      "  34.463536 34.463536 34.457124 34.457124 34.451    34.451    34.436816\n",
-      "  34.436816 34.417    34.417    34.409989 34.409989 34.397007 34.397007\n",
-      "  34.384998 34.384998 34.39103  34.39103  34.398998 34.398998 34.392772\n",
-      "  34.392772 34.382038 34.382038 34.372002 34.372002 34.367995 34.367995\n",
-      "  34.381001 34.381001 34.382727 34.382727 34.385952 34.385952 34.389\n",
-      "  34.389    34.396923 34.396923 34.408001 34.408001 34.411361 34.411361\n",
-      "  34.41776  34.41776  34.424    34.424    34.428107 34.428107 34.433998\n",
-      "  34.433998 34.438672 34.438672 34.447491 34.447491 34.456001 34.456001\n",
-      "  34.465877 34.465877 34.48     34.48     34.484211 34.484211 34.492308\n",
-      "  34.492308 34.5      34.5      34.510019 34.510019 34.523998 34.523998\n",
-      "  34.52851  34.52851  34.536944 34.536944 34.544998 34.544998 34.552077\n",
-      "  34.552077 34.562    34.562    34.564573 34.564573 34.569383 34.569383\n",
-      "  34.574001 34.574001 34.578186 34.578186 34.584    34.584    34.585986\n",
-      "  34.585986 34.589564 34.589564 34.592999 34.592999 34.5955   34.5955\n",
-      "  34.598999 34.598999 34.600743 34.600743 34.603942 34.603942 34.606998\n",
-      "  34.606998 34.609501 34.609501 34.612999 34.612999 34.614292 34.614292\n",
-      "  34.616688 34.616688 34.618999 34.618999 34.620662 34.620662 34.623001\n",
-      "  34.623001 34.624301 34.624301 34.626687 34.626687 34.629002 34.629002\n",
-      "  34.63023  34.63023  34.632    34.632   ]]\n",
-      "TEMP_MEDs(1, 200)\n",
-      "[[21.983     21.983     21.5159995 21.5159995 20.7440005 20.7440005\n",
-      "  20.124001  20.124001  19.3295    19.3295    18.62      18.62\n",
-      "  18.123     18.123     17.679001  17.679001  17.4735005 17.4735005\n",
-      "  16.844     16.844     16.389999  16.389999  16.007     16.007\n",
-      "  15.557     15.557     15.1325    15.1325    14.794     14.794\n",
-      "  14.284     14.284     13.951     13.951     13.484     13.484\n",
-      "  12.91      12.91      12.3475    12.3475    12.173699  12.173699\n",
-      "  11.598598  11.598598  11.035     11.035     10.579048  10.579048\n",
-      "   9.95       9.95       9.705498   9.705498   9.275      9.275\n",
-      "   8.951      8.951      8.608      8.608      8.171      8.171\n",
-      "   7.994871   7.994871   7.755      7.755      7.461      7.461\n",
-      "   7.268191   7.268191   7.059      7.059      6.963477   6.963477\n",
-      "   6.799637   6.799637   6.63       6.63       6.5502645  6.5502645\n",
-      "   6.3785     6.3785     6.316294   6.316294   6.204189   6.204189\n",
-      "   6.1025     6.1025     5.986183   5.986183   5.813      5.813\n",
-      "   5.7490115  5.7490115  5.6148935  5.6148935  5.47       5.47\n",
-      "   5.3792375  5.3792375  5.217      5.217      4.8853685  4.8853685\n",
-      "   4.7673145  4.7673145  4.6525     4.6525     4.546351   4.546351\n",
-      "   4.396      4.396      4.33777    4.33777    4.230815   4.230815\n",
-      "   4.126      4.126      4.0246145  4.0246145  3.8785     3.8785\n",
-      "   3.8298555  3.8298555  3.738073   3.738073   3.6495     3.6495\n",
-      "   3.574027   3.574027   3.458      3.458      3.411365   3.411365\n",
-      "   3.3250035  3.3250035  3.2435     3.2435     3.187438   3.187438\n",
-      "   3.109      3.109      3.081      3.081      3.0252625  3.0252625\n",
-      "   2.979      2.979      2.9427625  2.9427625  2.902      2.902\n",
-      "   2.8661615  2.8661615  2.843538   2.843538   2.8215     2.8215\n",
-      "   2.783216   2.783216   2.765      2.765      2.721182   2.721182\n",
-      "   2.694024   2.694024   2.6725     2.6725     2.647399   2.647399\n",
-      "   2.6152885  2.6152885  2.6070325  2.6070325  2.590732   2.590732\n",
-      "   2.5775     2.5775     2.5630005  2.5630005  2.545584   2.545584\n",
-      "   2.5354155  2.5354155  2.512865   2.512865   2.4945     2.4945\n",
-      "   2.4776765  2.4776765  2.462      2.462      2.4462685  2.4462685\n",
-      "   2.430285   2.430285   2.4155     2.4155     2.376439   2.376439\n",
-      "   2.387      2.387    ]]\n",
-      "TEMP_MINs(1, 200)\n",
-      "[[18.431999 18.431999 18.431    18.431    18.405001 18.405001 17.931999\n",
-      "  17.931999 17.459    17.459    16.898001 16.898001 16.577999 16.577999\n",
-      "  15.978    15.978    15.463    15.463    14.477    14.477    13.937\n",
-      "  13.937    13.662    13.662    13.052    13.052    12.584    12.584\n",
-      "  12.21     12.21     11.975    11.975    11.589    11.589    11.\n",
-      "  11.       10.605    10.605    10.109    10.109     9.849606  9.849606\n",
-      "   9.352636  9.352636  8.883     8.883     8.523503  8.523503  8.026\n",
-      "   8.026     7.902756  7.902756  7.674525  7.674525  7.428     7.428\n",
-      "   7.229613  7.229613  6.948     6.948     6.872269  6.872269  6.733313\n",
-      "   6.733313  6.573     6.573     6.482227  6.482227  6.354     6.354\n",
-      "   6.274196  6.274196  6.127768  6.127768  5.985     5.985     5.880442\n",
-      "   5.880442  5.731     5.731     5.675685  5.675685  5.567992  5.567992\n",
-      "   5.46      5.46      5.340454  5.340454  5.171     5.171     5.112563\n",
-      "   5.112563  5.004347  5.004347  4.806     4.806     4.715908  4.715908\n",
-      "   4.589     4.589     4.508319  4.508319  4.357512  4.357512  4.215\n",
-      "   4.215     4.144945  4.144945  4.047     4.047     4.00143   4.00143\n",
-      "   3.91463   3.91463   3.83      3.83      3.740852  3.740852  3.613\n",
-      "   3.613     3.570728  3.570728  3.490968  3.490968  3.414     3.414\n",
-      "   3.343217  3.343217  3.242     3.242     3.216526  3.216526  3.167538\n",
-      "   3.167538  3.121     3.121     3.062551  3.062551  2.981     2.981\n",
-      "   2.956506  2.956506  2.910723  2.910723  2.867     2.867     2.823285\n",
-      "   2.823285  2.762     2.762     2.74892   2.74892   2.724471  2.724471\n",
-      "   2.701     2.701     2.677982  2.677982  2.646     2.646     2.634084\n",
-      "   2.634084  2.612612  2.612612  2.592     2.592     2.572826  2.572826\n",
-      "   2.546     2.546     2.53619   2.53619   2.51819   2.51819   2.501\n",
-      "   2.501     2.485982  2.485982  2.465     2.465     2.456377  2.456377\n",
-      "   2.440409  2.440409  2.425     2.425     2.410867  2.410867  2.391\n",
-      "   2.391     2.379515  2.379515  2.358441  2.358441  2.338     2.338\n",
-      "   2.323666  2.323666  2.303     2.303   ]]\n",
-      "TEMP_MAXs(1, 200)\n",
-      "[[18.431999 25.636999 18.431    25.01     18.405001 24.504999 17.931999\n",
-      "  23.291    17.459    21.289    16.898001 20.158001 16.577999 19.415001\n",
-      "  15.978    18.871    15.463    18.389    14.477    17.976    13.937\n",
-      "  17.591    13.662    17.225    13.052    16.959    12.584    16.486\n",
-      "  12.21     16.153    11.975    15.706    11.589    15.185    11.\n",
-      "  14.71     10.605    14.071    10.109    13.495     9.849606 13.1825\n",
-      "   9.352636 12.735817  8.883    12.47      8.523503 11.903543  8.026\n",
-      "  11.125     7.902756 10.785378  7.674525 10.167882  7.428     9.825\n",
-      "   7.229613  9.386641  6.948     8.802     6.872269  8.676289  6.733313\n",
-      "   8.449783  6.573     8.238     6.482227  7.965529  6.354     7.592\n",
-      "   6.274196  7.508241  6.127768  7.35313   5.985     7.205     5.880442\n",
-      "   7.044808  5.731     6.821     5.675685  6.718915  5.567992  6.534093\n",
-      "   5.46      6.444     5.340454  6.313121  5.171     6.189     5.112563\n",
-      "   6.09537   5.004347  5.933939  4.806     5.783     4.715908  5.615206\n",
-      "   4.589     5.482     4.508319  5.166117  4.357512  5.037891  4.215\n",
-      "   4.918     4.144945  4.779952  4.047     4.613     4.00143   4.556567\n",
-      "   3.91463   4.457561  3.83      4.364     3.740852  4.21841   3.613\n",
-      "   4.019     3.570728  3.965067  3.490968  3.867889  3.414     3.778\n",
-      "   3.343217  3.685611  3.242     3.564     3.216526  3.522364  3.167538\n",
-      "   3.451191  3.121     3.385     3.062551  3.295189  2.981     3.176\n",
-      "   2.956506  3.155665  2.910723  3.120606  2.867     3.088     2.823285\n",
-      "   3.047502  2.762     2.993     2.74892   2.973998  2.724471  2.940066\n",
-      "   2.701     2.908     2.677982  2.872552  2.646     2.824     2.634084\n",
-      "   2.807984  2.612612  2.779126  2.592     2.752     2.572826  2.7255\n",
-      "   2.546     2.695     2.53619   2.676677  2.51819   2.653707  2.501\n",
-      "   2.632     2.485982  2.605998  2.465     2.592     2.456377  2.560349\n",
-      "   2.440409  2.541986  2.425     2.525     2.410867  2.507335  2.391\n",
-      "   2.519     2.379515  2.476769  2.358441  2.466306  2.338     2.456\n",
-      "   2.323666  2.429212  2.303     2.458   ]]\n"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/tmp/ipykernel_133376/2058552270.py:4: DeprecationWarning: `scipy.io.netcdf.NetCDFFile` is deprecated along with the `scipy.io.netcdf` namespace. `scipy.io.netcdf.NetCDFFile` will be removed in SciPy 1.14.0, and the `scipy.io.netcdf` namespace will be removed in SciPy 2.0.0.\n",
-      "  file = netcdf.NetCDFFile(\"./dataset/\"+filename, 'r')\n"
-     ]
-    }
-   ],
-   "source": [
-    "data = []\n",
-    "for filename in os.listdir(\"./dataset\"):\n",
-    "    data_dict = {}\n",
-    "    file = netcdf.NetCDFFile(\"./dataset/\"+filename, 'r')\n",
-    "    for key in file.variables.keys():\n",
-    "        print(str(key)+str(file.variables[key].shape))\n",
-    "        array1 = np.array(file.variables[key][:])\n",
-    "        print(array1)\n",
-    "    break"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 7,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/html": [
-       "<div>\n",
-       "<style scoped>\n",
-       "    .dataframe tbody tr th:only-of-type {\n",
-       "        vertical-align: middle;\n",
-       "    }\n",
-       "\n",
-       "    .dataframe tbody tr th {\n",
-       "        vertical-align: top;\n",
-       "    }\n",
-       "\n",
-       "    .dataframe thead th {\n",
-       "        text-align: right;\n",
-       "    }\n",
-       "</style>\n",
-       "<table border=\"1\" class=\"dataframe\">\n",
-       "  <thead>\n",
-       "    <tr style=\"text-align: right;\">\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th>ID</th>\n",
-       "      <th>PARAM</th>\n",
-       "      <th>LON</th>\n",
-       "      <th>LAT</th>\n",
-       "      <th>POS_QC</th>\n",
-       "      <th>BAT</th>\n",
-       "      <th>JULD</th>\n",
-       "      <th>JULD_QC</th>\n",
-       "      <th>FALSEorTRUE</th>\n",
-       "      <th>PRES</th>\n",
-       "      <th>...</th>\n",
-       "      <th>TEMP_MED</th>\n",
-       "      <th>TEMP_MIN</th>\n",
-       "      <th>TEMP_MAX</th>\n",
-       "      <th>PRESs</th>\n",
-       "      <th>PSAL_MEDs</th>\n",
-       "      <th>PSAL_MINs</th>\n",
-       "      <th>PSAL_MAXs</th>\n",
-       "      <th>TEMP_MEDs</th>\n",
-       "      <th>TEMP_MINs</th>\n",
-       "      <th>TEMP_MAXs</th>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <th>N_PROF</th>\n",
-       "      <th>N_LEVELS</th>\n",
-       "      <th>N_LEVELS_2</th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "      <th></th>\n",
-       "    </tr>\n",
-       "  </thead>\n",
-       "  <tbody>\n",
-       "    <tr>\n",
-       "      <th rowspan=\"11\" valign=\"top\">0</th>\n",
-       "      <th rowspan=\"5\" valign=\"top\">0</th>\n",
-       "      <th>0</th>\n",
-       "      <td>5902393.0</td>\n",
-       "      <td>b'S'</td>\n",
-       "      <td>-176.458633</td>\n",
-       "      <td>-29.23719</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>737288.878553</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.080000</td>\n",
-       "      <td>...</td>\n",
-       "      <td>21.983</td>\n",
-       "      <td>18.431999</td>\n",
-       "      <td>25.636999</td>\n",
-       "      <td>-5.0</td>\n",
-       "      <td>35.612000</td>\n",
-       "      <td>35.410000</td>\n",
-       "      <td>35.736000</td>\n",
-       "      <td>21.983000</td>\n",
-       "      <td>18.431999</td>\n",
-       "      <td>18.431999</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <th>1</th>\n",
-       "      <td>5902393.0</td>\n",
-       "      <td>b'S'</td>\n",
-       "      <td>-176.458633</td>\n",
-       "      <td>-29.23719</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>737288.878553</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.080000</td>\n",
-       "      <td>...</td>\n",
-       "      <td>21.983</td>\n",
-       "      <td>18.431999</td>\n",
-       "      <td>25.636999</td>\n",
-       "      <td>20.0</td>\n",
-       "      <td>35.612000</td>\n",
-       "      <td>35.410000</td>\n",
-       "      <td>35.736000</td>\n",
-       "      <td>21.983000</td>\n",
-       "      <td>18.431999</td>\n",
-       "      <td>25.636999</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <th>2</th>\n",
-       "      <td>5902393.0</td>\n",
-       "      <td>b'S'</td>\n",
-       "      <td>-176.458633</td>\n",
-       "      <td>-29.23719</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>737288.878553</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.080000</td>\n",
-       "      <td>...</td>\n",
-       "      <td>21.983</td>\n",
-       "      <td>18.431999</td>\n",
-       "      <td>25.636999</td>\n",
-       "      <td>20.0</td>\n",
-       "      <td>35.646189</td>\n",
-       "      <td>35.410000</td>\n",
-       "      <td>35.745998</td>\n",
-       "      <td>21.515999</td>\n",
-       "      <td>18.431000</td>\n",
-       "      <td>18.431000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <th>3</th>\n",
-       "      <td>5902393.0</td>\n",
-       "      <td>b'S'</td>\n",
-       "      <td>-176.458633</td>\n",
-       "      <td>-29.23719</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>737288.878553</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.080000</td>\n",
-       "      <td>...</td>\n",
-       "      <td>21.983</td>\n",
-       "      <td>18.431999</td>\n",
-       "      <td>25.636999</td>\n",
-       "      <td>40.0</td>\n",
-       "      <td>35.646189</td>\n",
-       "      <td>35.410000</td>\n",
-       "      <td>35.745998</td>\n",
-       "      <td>21.515999</td>\n",
-       "      <td>18.431000</td>\n",
-       "      <td>25.010000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <th>4</th>\n",
-       "      <td>5902393.0</td>\n",
-       "      <td>b'S'</td>\n",
-       "      <td>-176.458633</td>\n",
-       "      <td>-29.23719</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>737288.878553</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.080000</td>\n",
-       "      <td>...</td>\n",
-       "      <td>21.983</td>\n",
-       "      <td>18.431999</td>\n",
-       "      <td>25.636999</td>\n",
-       "      <td>40.0</td>\n",
-       "      <td>35.677499</td>\n",
-       "      <td>35.419998</td>\n",
-       "      <td>35.744999</td>\n",
-       "      <td>20.744000</td>\n",
-       "      <td>18.405001</td>\n",
-       "      <td>18.405001</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <th>...</th>\n",
-       "      <th>...</th>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "      <td>...</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <th rowspan=\"5\" valign=\"top\">1007</th>\n",
-       "      <th>195</th>\n",
-       "      <td>5902393.0</td>\n",
-       "      <td>b'S'</td>\n",
-       "      <td>-176.458633</td>\n",
-       "      <td>-29.23719</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>737288.878553</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>2006.079956</td>\n",
-       "      <td>...</td>\n",
-       "      <td>2.387</td>\n",
-       "      <td>2.303000</td>\n",
-       "      <td>2.458000</td>\n",
-       "      <td>1960.0</td>\n",
-       "      <td>34.614500</td>\n",
-       "      <td>34.604000</td>\n",
-       "      <td>34.629002</td>\n",
-       "      <td>2.415500</td>\n",
-       "      <td>2.338000</td>\n",
-       "      <td>2.456000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <th>196</th>\n",
-       "      <td>5902393.0</td>\n",
-       "      <td>b'S'</td>\n",
-       "      <td>-176.458633</td>\n",
-       "      <td>-29.23719</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>737288.878553</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>2006.079956</td>\n",
-       "      <td>...</td>\n",
-       "      <td>2.387</td>\n",
-       "      <td>2.303000</td>\n",
-       "      <td>2.458000</td>\n",
-       "      <td>1960.0</td>\n",
-       "      <td>34.617977</td>\n",
-       "      <td>34.605724</td>\n",
-       "      <td>34.630230</td>\n",
-       "      <td>2.376439</td>\n",
-       "      <td>2.323666</td>\n",
-       "      <td>2.323666</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <th>197</th>\n",
-       "      <td>5902393.0</td>\n",
-       "      <td>b'S'</td>\n",
-       "      <td>-176.458633</td>\n",
-       "      <td>-29.23719</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>737288.878553</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>2006.079956</td>\n",
-       "      <td>...</td>\n",
-       "      <td>2.387</td>\n",
-       "      <td>2.303000</td>\n",
-       "      <td>2.458000</td>\n",
-       "      <td>1980.0</td>\n",
-       "      <td>34.617977</td>\n",
-       "      <td>34.605724</td>\n",
-       "      <td>34.630230</td>\n",
-       "      <td>2.376439</td>\n",
-       "      <td>2.323666</td>\n",
-       "      <td>2.429212</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <th>198</th>\n",
-       "      <td>5902393.0</td>\n",
-       "      <td>b'S'</td>\n",
-       "      <td>-176.458633</td>\n",
-       "      <td>-29.23719</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>737288.878553</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>2006.079956</td>\n",
-       "      <td>...</td>\n",
-       "      <td>2.387</td>\n",
-       "      <td>2.303000</td>\n",
-       "      <td>2.458000</td>\n",
-       "      <td>1980.0</td>\n",
-       "      <td>34.618502</td>\n",
-       "      <td>34.608002</td>\n",
-       "      <td>34.632000</td>\n",
-       "      <td>2.387000</td>\n",
-       "      <td>2.303000</td>\n",
-       "      <td>2.303000</td>\n",
-       "    </tr>\n",
-       "    <tr>\n",
-       "      <th>199</th>\n",
-       "      <td>5902393.0</td>\n",
-       "      <td>b'S'</td>\n",
-       "      <td>-176.458633</td>\n",
-       "      <td>-29.23719</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>737288.878553</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>1.0</td>\n",
-       "      <td>2006.079956</td>\n",
-       "      <td>...</td>\n",
-       "      <td>2.387</td>\n",
-       "      <td>2.303000</td>\n",
-       "      <td>2.458000</td>\n",
-       "      <td>2050.0</td>\n",
-       "      <td>34.618502</td>\n",
-       "      <td>34.608002</td>\n",
-       "      <td>34.632000</td>\n",
-       "      <td>2.387000</td>\n",
-       "      <td>2.303000</td>\n",
-       "      <td>2.458000</td>\n",
-       "    </tr>\n",
-       "  </tbody>\n",
-       "</table>\n",
-       "<p>201600 rows × 28 columns</p>\n",
-       "</div>"
-      ],
-      "text/plain": [
-       "                                   ID PARAM         LON       LAT  POS_QC  \\\n",
-       "N_PROF N_LEVELS N_LEVELS_2                                                  \n",
-       "0      0        0           5902393.0  b'S' -176.458633 -29.23719     1.0   \n",
-       "                1           5902393.0  b'S' -176.458633 -29.23719     1.0   \n",
-       "                2           5902393.0  b'S' -176.458633 -29.23719     1.0   \n",
-       "                3           5902393.0  b'S' -176.458633 -29.23719     1.0   \n",
-       "                4           5902393.0  b'S' -176.458633 -29.23719     1.0   \n",
-       "...                               ...   ...         ...       ...     ...   \n",
-       "       1007     195         5902393.0  b'S' -176.458633 -29.23719     1.0   \n",
-       "                196         5902393.0  b'S' -176.458633 -29.23719     1.0   \n",
-       "                197         5902393.0  b'S' -176.458633 -29.23719     1.0   \n",
-       "                198         5902393.0  b'S' -176.458633 -29.23719     1.0   \n",
-       "                199         5902393.0  b'S' -176.458633 -29.23719     1.0   \n",
-       "\n",
-       "                            BAT           JULD  JULD_QC  FALSEorTRUE  \\\n",
-       "N_PROF N_LEVELS N_LEVELS_2                                             \n",
-       "0      0        0           1.0  737288.878553      1.0          1.0   \n",
-       "                1           1.0  737288.878553      1.0          1.0   \n",
-       "                2           1.0  737288.878553      1.0          1.0   \n",
-       "                3           1.0  737288.878553      1.0          1.0   \n",
-       "                4           1.0  737288.878553      1.0          1.0   \n",
-       "...                         ...            ...      ...          ...   \n",
-       "       1007     195         1.0  737288.878553      1.0          1.0   \n",
-       "                196         1.0  737288.878553      1.0          1.0   \n",
-       "                197         1.0  737288.878553      1.0          1.0   \n",
-       "                198         1.0  737288.878553      1.0          1.0   \n",
-       "                199         1.0  737288.878553      1.0          1.0   \n",
-       "\n",
-       "                                   PRES  ...  TEMP_MED   TEMP_MIN   TEMP_MAX  \\\n",
-       "N_PROF N_LEVELS N_LEVELS_2               ...                                   \n",
-       "0      0        0              1.080000  ...    21.983  18.431999  25.636999   \n",
-       "                1              1.080000  ...    21.983  18.431999  25.636999   \n",
-       "                2              1.080000  ...    21.983  18.431999  25.636999   \n",
-       "                3              1.080000  ...    21.983  18.431999  25.636999   \n",
-       "                4              1.080000  ...    21.983  18.431999  25.636999   \n",
-       "...                                 ...  ...       ...        ...        ...   \n",
-       "       1007     195         2006.079956  ...     2.387   2.303000   2.458000   \n",
-       "                196         2006.079956  ...     2.387   2.303000   2.458000   \n",
-       "                197         2006.079956  ...     2.387   2.303000   2.458000   \n",
-       "                198         2006.079956  ...     2.387   2.303000   2.458000   \n",
-       "                199         2006.079956  ...     2.387   2.303000   2.458000   \n",
-       "\n",
-       "                             PRESs  PSAL_MEDs  PSAL_MINs  PSAL_MAXs  \\\n",
-       "N_PROF N_LEVELS N_LEVELS_2                                            \n",
-       "0      0        0             -5.0  35.612000  35.410000  35.736000   \n",
-       "                1             20.0  35.612000  35.410000  35.736000   \n",
-       "                2             20.0  35.646189  35.410000  35.745998   \n",
-       "                3             40.0  35.646189  35.410000  35.745998   \n",
-       "                4             40.0  35.677499  35.419998  35.744999   \n",
-       "...                            ...        ...        ...        ...   \n",
-       "       1007     195         1960.0  34.614500  34.604000  34.629002   \n",
-       "                196         1960.0  34.617977  34.605724  34.630230   \n",
-       "                197         1980.0  34.617977  34.605724  34.630230   \n",
-       "                198         1980.0  34.618502  34.608002  34.632000   \n",
-       "                199         2050.0  34.618502  34.608002  34.632000   \n",
-       "\n",
-       "                            TEMP_MEDs  TEMP_MINs  TEMP_MAXs  \n",
-       "N_PROF N_LEVELS N_LEVELS_2                                   \n",
-       "0      0        0           21.983000  18.431999  18.431999  \n",
-       "                1           21.983000  18.431999  25.636999  \n",
-       "                2           21.515999  18.431000  18.431000  \n",
-       "                3           21.515999  18.431000  25.010000  \n",
-       "                4           20.744000  18.405001  18.405001  \n",
-       "...                               ...        ...        ...  \n",
-       "       1007     195          2.415500   2.338000   2.456000  \n",
-       "                196          2.376439   2.323666   2.323666  \n",
-       "                197          2.376439   2.323666   2.429212  \n",
-       "                198          2.387000   2.303000   2.303000  \n",
-       "                199          2.387000   2.303000   2.458000  \n",
-       "\n",
-       "[201600 rows x 28 columns]"
-      ]
-     },
-     "execution_count": 7,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "import xarray as xr\n",
-    "dataframes = []\n",
-    "for filename in os.listdir(\"./dataset\"):\n",
-    "    #file2read = netcdf.NetCDFFile(, 'r')\n",
-    "    ds = xr.open_dataset(\"./dataset/\"+filename)\n",
-    "    df = ds.to_dataframe()\n",
-    "    dataframes.append(df)\n",
-    "    break\n",
-    "pd.set_option('display.max_rows', 100)\n",
-    "df"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 44,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "(781,)"
-      ]
-     },
-     "execution_count": 44,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "df['TEMP_MED'].unique().shape"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 39,
-   "metadata": {},
-   "outputs": [
-    {
-     "ename": "AttributeError",
-     "evalue": "module 'cf_xarray' has no attribute 'open_dataset'",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
-      "Cell \u001b[0;32mIn[39], line 10\u001b[0m\n\u001b[1;32m      7\u001b[0m     \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[1;32m      9\u001b[0m \u001b[38;5;66;03m# Apply cf-xarray to decode CF conventions\u001b[39;00m\n\u001b[0;32m---> 10\u001b[0m ds_cfxr \u001b[38;5;241m=\u001b[39m \u001b[43mcfxr\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopen_dataset\u001b[49m(ds)\n\u001b[1;32m     12\u001b[0m \u001b[38;5;66;03m# Convert to pandas DataFrame\u001b[39;00m\n\u001b[1;32m     13\u001b[0m df \u001b[38;5;241m=\u001b[39m ds_cfxr\u001b[38;5;241m.\u001b[39mto_dataframe()\n",
-      "\u001b[0;31mAttributeError\u001b[0m: module 'cf_xarray' has no attribute 'open_dataset'"
-     ]
-    }
-   ],
-   "source": [
-    "import cf_xarray as cfxr\n",
-    "import pandas as pd\n",
-    "\n",
-    "# Open the netCDF file using xarray\n",
-    "for f in os.listdir(\"./dataset\"):\n",
-    "    ds = xr.open_dataset(\"./dataset/\"+f)\n",
-    "    break\n",
-    "\n",
-    "# Apply cf-xarray to decode CF conventions\n",
-    "ds_cfxr = cfxr.open_dataset(ds)\n",
-    "\n",
-    "# Convert to pandas DataFrame\n",
-    "df = ds_cfxr.to_dataframe()\n",
-    "\n",
-    "# Optionally, close the netCDF dataset\n",
-    "ds.close()\n",
-    "\n",
-    "# Now you have the data in a pandas DataFrame\n",
-    "print(df)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 56,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "Frozen({'ID': <xarray.Variable (N_PROF: 1)> Size: 8B\n",
-       "[1 values with dtype=float64], 'PARAM': <xarray.Variable (N_PROF: 1)> Size: 1B\n",
-       "[1 values with dtype=|S1], 'LON': <xarray.Variable (N_PROF: 1)> Size: 8B\n",
-       "[1 values with dtype=float64], 'LAT': <xarray.Variable (N_PROF: 1)> Size: 8B\n",
-       "[1 values with dtype=float64], 'POS_QC': <xarray.Variable (N_PROF: 1)> Size: 8B\n",
-       "[1 values with dtype=float64], 'BAT': <xarray.Variable (N_PROF: 1)> Size: 8B\n",
-       "[1 values with dtype=float64], 'JULD': <xarray.Variable (N_PROF: 1)> Size: 8B\n",
-       "[1 values with dtype=float64], 'JULD_QC': <xarray.Variable (N_PROF: 1)> Size: 8B\n",
-       "[1 values with dtype=float64], 'FALSEorTRUE': <xarray.Variable (N_PROF: 1)> Size: 8B\n",
-       "[1 values with dtype=float64], 'PRES': <xarray.Variable (N_PROF: 1, N_LEVELS: 72)> Size: 576B\n",
-       "[72 values with dtype=float64], 'PRES_QC': <xarray.Variable (N_PROF: 1, N_LEVELS: 72)> Size: 576B\n",
-       "[72 values with dtype=float64], 'PSAL': <xarray.Variable (N_PROF: 1, N_LEVELS: 72)> Size: 576B\n",
-       "[72 values with dtype=float64], 'PSAL_QC': <xarray.Variable (N_PROF: 1, N_LEVELS: 72)> Size: 576B\n",
-       "[72 values with dtype=float64], 'PSAL_MED': <xarray.Variable (N_PROF: 1, N_LEVELS: 72)> Size: 576B\n",
-       "[72 values with dtype=float64], 'PSAL_MIN': <xarray.Variable (N_PROF: 1, N_LEVELS: 72)> Size: 576B\n",
-       "[72 values with dtype=float64], 'PSAL_MAX': <xarray.Variable (N_PROF: 1, N_LEVELS: 72)> Size: 576B\n",
-       "[72 values with dtype=float64], 'TEMP': <xarray.Variable (N_PROF: 1, N_LEVELS: 72)> Size: 576B\n",
-       "[72 values with dtype=float64], 'TEMP_QC': <xarray.Variable (N_PROF: 1, N_LEVELS: 72)> Size: 576B\n",
-       "[72 values with dtype=float64], 'TEMP_MED': <xarray.Variable (N_PROF: 1, N_LEVELS: 72)> Size: 576B\n",
-       "[72 values with dtype=float64], 'TEMP_MIN': <xarray.Variable (N_PROF: 1, N_LEVELS: 72)> Size: 576B\n",
-       "[72 values with dtype=float64], 'TEMP_MAX': <xarray.Variable (N_PROF: 1, N_LEVELS: 72)> Size: 576B\n",
-       "[72 values with dtype=float64], 'PRESs': <xarray.Variable (N_LEVELS_2: 200)> Size: 2kB\n",
-       "[200 values with dtype=float64], 'PSAL_MEDs': <xarray.Variable (N_PROF: 1, N_LEVELS_2: 200)> Size: 2kB\n",
-       "[200 values with dtype=float64], 'PSAL_MINs': <xarray.Variable (N_PROF: 1, N_LEVELS_2: 200)> Size: 2kB\n",
-       "[200 values with dtype=float64], 'PSAL_MAXs': <xarray.Variable (N_PROF: 1, N_LEVELS_2: 200)> Size: 2kB\n",
-       "[200 values with dtype=float64], 'TEMP_MEDs': <xarray.Variable (N_PROF: 1, N_LEVELS_2: 200)> Size: 2kB\n",
-       "[200 values with dtype=float64], 'TEMP_MINs': <xarray.Variable (N_PROF: 1, N_LEVELS_2: 200)> Size: 2kB\n",
-       "[200 values with dtype=float64], 'TEMP_MAXs': <xarray.Variable (N_PROF: 1, N_LEVELS_2: 200)> Size: 2kB\n",
-       "[200 values with dtype=float64]})"
-      ]
-     },
-     "execution_count": 56,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "import xarray as xr\n",
-    "\n",
-    "ds = xr.open_dataset('./dataset/LD_1900973.nc')\n",
-    "ds.variables"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
  "metadata": {
diff --git a/mlp_train.py b/mlp_train.py
index 6980e9d4b05c8421b47f15aa12817b8acd0df3aa..312e235750273096ff2d1188ad6bc99ac00fb753 100644
--- a/mlp_train.py
+++ b/mlp_train.py
@@ -4,16 +4,9 @@ import numpy as np
 import torch
 from tqdm import trange, tqdm
 from sklearn.model_selection import train_test_split
-import matplotlib.pyplot as plt
 from sklearn.metrics import confusion_matrix
-import seaborn as sns
-from sklearn.metrics import roc_auc_score
 
 
-# S_features_filter = [abs_S_Smin,rel_S_Smin_semi_width,rel_S_Smin_full_width,abs_S_Smax,rel_S_Smax_semi_width,rel_S_Smax_full_width,count_anomalies_S,ratio_anomalies_S,max_variation_S]
-# T_features_filter = [abs_T_Tmin,rel_T_Tmin_semi_width,rel_T_Tmin_full_width,abs_T_Tmax,rel_T_Tmax_semi_width,rel_T_Tmax_full_width,count_anomalies_T,ratio_anomalies_T,max_variation_T]
-# B_features_filter = [mean_correlation,nb_measurements]
-
 S_features_filter = [
     "abs_S_Smin",
     "rel_S_Smin_semi_width",
@@ -39,24 +32,28 @@ T_features_filter = [
 B_features_filter = ["mean_correlation", "nb_measurements"]
 
 PICKLE_PATH = "dataset_pandas/temperature.pkl"
+RANDOM_SEED = 123456789
 
 
 ##### HYPERPARAMETERS #####
-EPOCHS = 300
-BATCH_SIZE = 32
+EPOCHS = 250 #350
+BATCH_SIZE = 32 #16 
 CRITERION = nn.BCELoss()
-OPTIMIZER = torch.optim.Adam
-LEARNING_RATE = 0.01
-GROWTH_RATE = 16
-DROP_RATE = 0.5
-SCHEDULER_PATIENCE = 10
-SCHEDULER_FACTOR = 0.5
+OPTIMIZER = torch.optim.Adam #torch.optim.SGD 
+LEARNING_RATE = 5*1e-3 #*1e-3
+GROWTH_RATE = 32 #16 
+DROP_RATE = 0.5 #0.2
+SCHEDULER_PATIENCE = 10 #15
+SCHEDULER_FACTOR = 0.1
 SCHEDULER_EPS = 1e-8
 
+PLOTS = []
 
-input_features = 11
 
+input_features = 11
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+##### MLP Definition #####
 class MLP(nn.Module):
     def __init__(self):
         super(MLP, self).__init__()
@@ -101,8 +98,8 @@ class MLP(nn.Module):
         y = torch.sigmoid(y)
         return y
 
-
-def prepare_data() -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+########## PREPROCESSING ############
+def prepare_test_train_df(data_seed = RANDOM_SEED) -> tuple[np.ndarray, np.ndarray]:
     # Load data
     df = pd.read_pickle(PICKLE_PATH)
 
@@ -110,43 +107,62 @@ def prepare_data() -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
     df = df.drop(columns=S_features_filter)
 
     # split the into training and testing sets
-    train_df, test_df = train_test_split(df, test_size=0.2, random_state=123456789)
+    train_df, test_df = train_test_split(df, test_size=0.2, random_state=data_seed)
 
     print(
         f"Train alarm distribution (befor undersampling): {train_df['alarm'].value_counts()}"
     )
 
-    # balance the training set
+    # Balance the training set
     min_alarm = train_df["alarm"].value_counts().min()
     train_df = pd.concat(
         [
-            train_df[train_df["alarm"] == 0].sample(min_alarm),
-            train_df[train_df["alarm"] == 1].sample(min_alarm),
+            train_df[train_df["alarm"] == 0].sample(min_alarm, random_state=data_seed),
+            train_df[train_df["alarm"] == 1].sample(min_alarm, random_state=data_seed),
         ]
     )
     print(f"Train alarm distribution: {train_df['alarm'].value_counts()}")
+    return train_df, test_df
 
+def prepare_features_and_labels(train_df, test_df) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+    # Separate train dataset into features and labels
     X_train = train_df.drop(columns=["alarm"]).values
     y_train = train_df["alarm"].values
 
+    # Separate test dataset into features and labels
     X_test = test_df.drop(columns=["alarm"]).values
     y_test = test_df["alarm"].values
 
     return X_train, y_train, X_test, y_test
+    
+    
+########## TRAINING ############
+def train_model(X_train, y_train, X_test, y_test, model_seed = RANDOM_SEED):
 
-
-def train_model(X_train, y_train, X_test, y_test):
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # Fix the model seed for reproducibility
+    torch.manual_seed(model_seed)
 
     # Setting up the data loader
     train_loader = torch.utils.data.DataLoader(
-        list(zip(X_train, y_train)), batch_size=BATCH_SIZE, shuffle=True
+        list(zip(X_train, y_train)),
+        batch_size=BATCH_SIZE,
+        shuffle=True
+    )
+
+    # Setting up the test loader
+    test_loader = torch.utils.data.DataLoader(
+        list(zip(X_test, y_test)),
+        batch_size=BATCH_SIZE,
+        shuffle=False
     )
 
+    # Fix the model seed for reproducibility
+    torch.manual_seed(model_seed)
+
     # Define model
     model = MLP().to(device)
 
+    # Set Dropout rate
     model.dropout = nn.Dropout(DROP_RATE)
 
     # Define loss function
@@ -167,25 +183,33 @@ def train_model(X_train, y_train, X_test, y_test):
     # Train model
     train_losses = []
     test_losses = []
+    train_accuracies = []
+    test_accuracies = []
     for epoch in range(EPOCHS):
         epoch_train_loss = 0
+        total_train = 0
+        correct_train = 0
         with tqdm(train_loader, unit="batch") as t:
             for data, target in t:
                 t.set_description(f"Epoch {str(epoch).rjust(5)}")
 
-                # Move data to device
+                # Move data to device and set model to train mode
                 data, target = data.to(device), target.to(device)
+                model.train()
 
-                # Zero the gradients
+                # Zero the gradients and forward pass
                 optimizer.zero_grad()
-
                 output = model(data.float())
 
                 # Calculate loss
                 loss = criterion(output, target.float().view(-1, 1))
-
                 epoch_train_loss += loss.item()
 
+                # Calculate accuracy
+                y_train_pred_binary = np.where(output.data > 0.5, 1, 0).reshape(1, len(output))[0]
+                correct_train += (target.numpy() == y_train_pred_binary).sum().item()
+                total_train += len(target)
+                
                 # Backpropagation
                 loss.backward()
 
@@ -195,13 +219,24 @@ def train_model(X_train, y_train, X_test, y_test):
                 # Display loss
                 t.set_postfix(train_loss=f"{loss.item():.4f}")
 
-            scheduler.step(loss)
-            # print optimizer learning rate
-            print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
+        # Compute total train accuracy
+        train_accuracy = 100 * correct_train / total_train
+        train_accuracies.append(train_accuracy)
+
+        print(f"Train accuracy: {train_accuracy}")
+
+        # print optimizer learning rate
+        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
 
-            # compute train loss
-            epoch_train_loss /= len(train_loader)
+        # compute train loss
+        epoch_train_loss /= len(train_loader)
 
+        # update scheduler
+        scheduler.step(epoch_train_loss)
+
+        # set model to evaluation mode
+        model.eval()
+        with torch.no_grad():
             # Evaluate model on test set
             y_test_pred = (
                 model(torch.tensor(X_test).float().to(device)).cpu().detach().numpy()
@@ -211,10 +246,18 @@ def train_model(X_train, y_train, X_test, y_test):
                 torch.tensor(y_test).float().view(-1, 1),
             )
 
+            train_losses.append(epoch_train_loss)
             print(f"Train loss: {epoch_train_loss:.4f}")
-            print(f"Test loss: {test_loss.item():.4f}")
+
+            y_test_pred_binary = np.where(y_test_pred > 0.5, 1, 0).reshape(1, len(y_test_pred))
+            correct_test = (y_test_pred_binary == y_test).sum().item()
+            test_accuracy = 100 * correct_test / len(y_test)
+            test_accuracies.append(test_accuracy)            
+            print(f"Test accuracy: {test_accuracy}")
+
             test_losses.append(test_loss.item())
-            train_losses.append(epoch_train_loss)
+            print(f"Test loss: {test_loss.item():.4f}")
+            
 
             # save model if test loss has decreased
             if len(test_losses) == 1 or test_loss < min(test_losses[:-1]):
@@ -222,69 +265,81 @@ def train_model(X_train, y_train, X_test, y_test):
                     model.state_dict(),
                     f"checkpoints/mlp_{epoch}.pth",
                 )
+            
+    return model, train_losses, test_losses, train_accuracies, test_accuracies
 
-    # Plot losses
-    sns.lineplot(x=range(len(train_losses)), y=train_losses, label="Train loss")
-    sns.lineplot(x=range(len(test_losses)), y=test_losses, label="Test loss")
-    plt.xlabel("Epoch")
-    plt.ylabel("Loss")
-    plt.legend()
-    plt.show()
+########## POST TRAINING & EVALUATION ############
 
-    # load best model from checkpoint
+def load_model_from_checkpoint(model, test_losses):
+    #load best model from checkpoint
     print(f"Loading model from checkpoint: mlp_{np.argmin(test_losses)}.pth")
     model.load_state_dict(torch.load(f"checkpoints/mlp_{np.argmin(test_losses)}.pth"))
-
+    return model
+    
+def get_evaluation(model, X_test):
     # predict on test set
-    y_test_pred = model(torch.tensor(X_test).float().to(device)).cpu().detach().numpy()
+    model.eval()
+    with torch.no_grad():
+        y_test_pred = (
+            model(torch.tensor(X_test).float().to(device)).cpu().detach().numpy()
+        )
     y_test_pred_binary = np.where(y_test_pred > 0.5, 1, 0)
 
+    # print parameter count of model
+    # print(f"Parameter count: {sum(p.numel() for p in model.parameters())}")
+
+    return y_test_pred, y_test_pred_binary
+
+def get_confusion_matrix(y_test_pred_binary, y_test):
     # calculate confusion matrix
     cm = confusion_matrix(y_test, y_test_pred_binary)
+    return cm
 
-    print(cm)
-
+def get_metrics(cm):
     # calculate accuracy
     accuracy = np.sum(np.diag(cm)) / np.sum(cm)
     print(f"Accuracy: {accuracy}")
 
-    # print recall
+    # calculate recall
     recall = cm[1, 1] / (cm[1, 0] + cm[1, 1])
     print(f"Recall: {recall}")
 
-    # print precision
+    # calculate precision
     precision = cm[1, 1] / (cm[0, 1] + cm[1, 1])
     print(f"Precision: {precision}")
 
-    # print F1 score
+    # calculate F1 score
     f1 = 2 * (precision * recall) / (precision + recall)
     print(f"F1 score: {f1}")
 
-    # print F2 score
+    # calculate F2 score
     f2 = 5 * (precision * recall) / (4 * precision + recall)
     print(f"F2 score: {f2}")
 
-    # print AUC
-    auc = roc_auc_score(y_test, y_test_pred)
-
-    print(f"AUC: {auc}")
-
-    # plot confusion matrix using seaborn
-    sns.heatmap(cm, annot=True, fmt="d")
-    plt.xlabel("Predicted")
-    plt.ylabel("True")
-    plt.show()
-
-    return model
-
-
-def main():
-    X_train, y_train, X_test, y_test = prepare_data()
-    model = train_model(X_train, y_train, X_test, y_test)
-
-    # print parameter count of model
-    print(f"Parameter count: {sum(p.numel() for p in model.parameters())}")
-
-
-if __name__ == "__main__":
-    main()
+    return accuracy, precision, recall, f1, f2
+
+def sort_testdata_into_cm(test_df, y_test_pred, y_test_pred_binary):
+    # Function to determine TP, TN, FP, FN
+    def determine_result(row):
+        if row['alarm'] == 1 and row['predicted'] == 1:
+            return 'TP'
+        elif row['alarm'] == 0 and row['predicted'] == 0:
+            return 'TN'
+        elif row['alarm'] == 0 and row['predicted'] == 1:
+            return 'FP'
+        elif row['alarm'] == 1 and row['predicted'] == 0:
+            return 'FN'
+    
+    # Sort the predictions into CM categories 
+    y_pred_series = pd.Series(y_test_pred.reshape(1, -1)[0])
+    y_pred_binary_series = pd.Series(y_test_pred_binary.reshape(1, -1)[0])
+    results = test_df.copy()
+    y_pred_series.index = results.index
+    y_pred_binary_series.index = results.index
+    results['predicted'] = y_pred_binary_series.astype(bool)
+    results['predicted_raw'] = y_pred_series
+    
+    # Create new column based on conditions for the confusion matrix
+    results['CM'] = results.apply(determine_result, axis=1)
+
+    return results
\ No newline at end of file
diff --git a/plot_results.py b/plot_results.py
new file mode 100644
index 0000000000000000000000000000000000000000..b79077c296c1e60cb00835af3327afd2c8d3ec7c
--- /dev/null
+++ b/plot_results.py
@@ -0,0 +1,68 @@
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+from sklearn.preprocessing import StandardScaler
+
+
+######### PLOTTING FUNCTIONS ########
+def plot_train_test_evolution(train_losses, test_losses, ax, metric = 'loss'):
+    # Plot losses
+    sns.lineplot(x=range(len(train_losses)), y=train_losses, label=f"Train {metric}", ax=ax)
+    sns.lineplot(x=range(len(test_losses)), y=test_losses, label=f"Test {metric}", ax=ax)
+    ax.set_xlabel("Epoch")
+    ax.set_ylabel("Loss")
+
+def plot_cm(cm, ax):
+    # plot confusion matrix
+    sns.heatmap(cm, annot=True, fmt="d", ax=ax)
+    ax.set_xlabel("Predicted")
+    ax.set_ylabel("True")
+
+def plot_joint_pdf(joint_pdf_train, joint_pdf_test, ax):
+    # Plot the joint PDF along the first dimension)
+    sns.kdeplot(data=joint_pdf_train, label = 'Train data', ax=ax, common_norm = True)
+    sns.kdeplot(data=joint_pdf_test, label = 'Test data', ax=ax, common_norm = True)
+    ax.set_xlabel('X[11 dimensions]')
+    ax.set_ylabel('Density')
+    ax.set_title('Joint PDF of the First Variable')
+    ax.legend()
+
+# This plots the features separately (not sure if we still need this)
+def plot_features(df, error_type, ax):
+    scaler = StandardScaler()
+    df_normalized = pd.DataFrame(scaler.fit_transform(df), columns=df.columns)
+    sns.boxplot(data=df_normalized, ax = ax)
+    ax.set_title(error_type)
+    ax.set_xticks(rotation=90)
+
+###################### LOGGING FUNCTIONS
+# This adds a line in the csv file fo the logs
+def save_metrics(filename, data_seed, model_seed, accuracy, recall, precision, f1, f2):
+    with open(f"./logs/{filename}.csv", "a") as f:
+        f.write(f'{data_seed}, {model_seed}, {accuracy}, {recall}, {precision}, {f1}, {f2}\n')
+
+############# COMPARE OUR ACCURACIES WITH ROMARIC
+def plot_comparison_resultS():
+    FILE_PATHS = ['./logs/random.csv', './logs/fix-data.csv','./logs/fix-model.csv', './logs/fix-data-param-romaric.csv', './logs/fix-model-param-romaric.csv']
+    columns = ['Accuracy'] #['Accuracy', 'Precision', 'Recall']
+    df0 = pd.read_csv(FILE_PATHS[0], header = 0)[columns]
+    df1 = pd.read_csv(FILE_PATHS[1], header = 0)[columns]
+    df2 = pd.read_csv(FILE_PATHS[2], header = 0)[columns]
+    df3 = pd.read_csv(FILE_PATHS[3], header = 0)[columns]
+    df4 = pd.read_csv(FILE_PATHS[4], header = 0)[columns]
+
+    fig, axes = plt.subplots(2, 3, figsize=(12, 10), sharey=True)
+    sns.boxplot(data=df0, ax = axes[0, 0])
+    sns.boxplot(data=df1, ax = axes[0, 1])
+    sns.boxplot(data=df2, ax = axes[0, 2])
+
+    sns.boxplot(data=df3, ax = axes[1, 1])
+    sns.boxplot(data=df4, ax = axes[1, 2])
+    axes[0, 0].set_title('Random Seeds')
+    axes[0, 1].set_title('Fix data, Random mode seeds')
+    axes[0, 2].set_title('Fix model, Random data seeds')
+    axes[1, 1].set_title('Fix data (Romaric)')
+    axes[1, 2].set_title('Fix model (Romaric)')
+    plt.tight_layout()
+    plt.savefig(f'error_distribution.pdf')
+    plt.show()
\ No newline at end of file