Commit 4d1c77ee authored by Konstantin Gerd Eyhorn

Merge branch 'net_eval_observations'

parents b5959b77 e791d691
import pandas as pd
from sklearn.preprocessing import StandardScaler
from scipy.stats import gaussian_kde


def get_joint_pdf(df_train, df_test):
    # Estimate the joint PDF of the features with a Gaussian KDE fitted on the training set
    scaler = StandardScaler()
    df_train_normalized = pd.DataFrame(scaler.fit_transform(df_train), columns=df_train.columns)
    # Scale the test set with the statistics fitted on the training set (transform, not fit_transform)
    df_test_normalized = pd.DataFrame(scaler.transform(df_test), columns=df_test.columns)
    # gaussian_kde expects shape (n_features, n_samples), hence the transposes
    kde = gaussian_kde(df_train_normalized.T)
    # Evaluate the KDE on both sets
    joint_pdf_train = kde(df_train_normalized.T)
    joint_pdf_test = kde(df_test_normalized.T)
    return joint_pdf_train, joint_pdf_test
\ No newline at end of file
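A minimal sketch of how get_joint_pdf can be exercised, assuming synthetic stand-in data (the column names, shapes, and values below are illustrative, not taken from the temperature dataset):

import numpy as np
import pandas as pd
from analysis import get_joint_pdf

# Illustrative stand-ins for the real train/test feature frames
rng = np.random.default_rng(0)
cols = [f"f{i}" for i in range(5)]
df_train = pd.DataFrame(rng.normal(size=(200, 5)), columns=cols)
df_test = pd.DataFrame(rng.normal(loc=0.2, size=(50, 5)), columns=cols)

pdf_train, pdf_test = get_joint_pdf(df_train, df_test)
# One density value per sample
print(pdf_train.shape, pdf_test.shape)  # (200,) (50,)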
DataSeed,ModelSeed,Accuracy,Recall,Precision,F1 Score,F2 Score,AUC
123456789, 821, 0.7738095238095238, 0.9565217391304348, 0.7213114754098361, 0.822429906542056, 0.8979591836734693, 0.9410755148741418
123456789, 6712, 0.7619047619047619, 0.9347826086956522, 0.7166666666666667, 0.8113207547169811, 0.8811475409836066, 0.9445080091533181
123456789, 8255, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.931350114416476
123456789, 4502, 0.7619047619047619, 0.9565217391304348, 0.7096774193548387, 0.8148148148148149, 0.894308943089431, 0.908466819221968
123456789, 3403, 0.7857142857142857, 0.9347826086956522, 0.7413793103448276, 0.826923076923077, 0.8884297520661157, 0.9199084668192219
123456789, 3008, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9296338672768879
123456789, 123, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.9336384439359268
123456789, 6564, 0.7738095238095238, 0.9565217391304348, 0.7213114754098361, 0.822429906542056, 0.8979591836734693, 0.9284897025171625
123456789, 4610, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9290617848970252
123456789, 3914, 0.7738095238095238, 0.9565217391304348, 0.7213114754098361, 0.822429906542056, 0.8979591836734693, 0.9364988558352403
123456789, 2658, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9256292906178489
123456789, 8773, 0.7738095238095238, 0.9347826086956522, 0.7288135593220338, 0.819047619047619, 0.8847736625514403, 0.937070938215103
123456789, 4258, 0.7619047619047619, 0.9130434782608695, 0.7241379310344828, 0.8076923076923076, 0.8677685950413221, 0.9221967963386727
123456789, 1430, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.933066361556064
123456789, 2979, 0.7857142857142857, 0.9347826086956522, 0.7413793103448276, 0.826923076923077, 0.8884297520661157, 0.9307780320366132
123456789, 3935, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9233409610983982
123456789, 6961, 0.7619047619047619, 0.9347826086956522, 0.7166666666666667, 0.8113207547169811, 0.8811475409836066, 0.9290617848970252
123456789, 2381, 0.7738095238095238, 0.9130434782608695, 0.7368421052631579, 0.8155339805825242, 0.871369294605809, 0.9296338672768879
123456789, 4368, 0.7619047619047619, 0.9347826086956522, 0.7166666666666667, 0.8113207547169811, 0.8811475409836066, 0.9439359267734554
123456789, 3816, 0.7857142857142857, 0.9347826086956522, 0.7413793103448276, 0.826923076923077, 0.8884297520661157, 0.9336384439359268
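The F1 Score and F2 Score columns in these logs follow the usual F-beta definitions used in get_metrics below; as a quick sanity check against the first row of the table above (Precision 0.7213..., Recall 0.9565...):

# Recompute F1/F2 from the first logged row (values copied from the table above)
precision, recall = 0.7213114754098361, 0.9565217391304348
f1 = 2 * precision * recall / (precision + recall)      # -> 0.822429906542056
f2 = 5 * precision * recall / (4 * precision + recall)  # -> 0.8979591836734693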
DataSeed,ModelSeed,Accuracy,Recall,Precision,F1 Score,F2 Score,AUC
123456789, 4224, 0.7738095238095238, 0.9347826086956522, 0.7288135593220338, 0.819047619047619, 0.8847736625514403, 0.9290617848970252
123456789, 7729, 0.7857142857142857, 0.9782608695652174, 0.7258064516129032, 0.8333333333333333, 0.9146341463414634, 0.9170480549199085
123456789, 197, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9216247139588101
123456789, 8678, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.927345537757437
123456789, 1418, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9193363844393593
123456789, 1742, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.9170480549199085
123456789, 2456, 0.7857142857142857, 0.8913043478260869, 0.7592592592592593, 0.8200000000000001, 0.8613445378151261, 0.9302059496567505
123456789, 5305, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.8844393592677346
123456789, 6591, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.9262013729977117
123456789, 4043, 0.75, 0.9130434782608695, 0.711864406779661, 0.7999999999999999, 0.8641975308641975, 0.9090389016018308
123456789, 6737, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.9302059496567506
123456789, 3856, 0.7976190476190477, 0.8913043478260869, 0.7735849056603774, 0.8282828282828283, 0.8649789029535865, 0.925629290617849
123456789, 4955, 0.7380952380952381, 0.9130434782608695, 0.7, 0.7924528301886793, 0.8606557377049181, 0.9193363844393593
123456789, 3355, 0.7738095238095238, 0.9347826086956522, 0.7288135593220338, 0.819047619047619, 0.8847736625514403, 0.9124713958810069
123456789, 9015, 0.7142857142857143, 0.9565217391304348, 0.6666666666666666, 0.7857142857142856, 0.88, 0.9204805491990846
123456789, 4165, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.9347826086956522
123456789, 4498, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.9244851258581236
123456789, 7503, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9279176201372998
123456789, 7120, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.8781464530892449
123456789, 5406, 0.7976190476190477, 0.9130434782608695, 0.7636363636363637, 0.8316831683168316, 0.8786610878661087, 0.9290617848970252
123456789, 1173, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.9147597254004577
123456789, 810, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.9290617848970252
123456789, 9865, 0.7857142857142857, 0.9347826086956522, 0.7413793103448276, 0.826923076923077, 0.8884297520661157, 0.8970251716247141
123456789, 5329, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.931350114416476
DataSeed,ModelSeed,Accuracy,Recall,Precision,F1 Score,F2 Score,AUC
5517, 123456789, 0.7261904761904762, 0.9, 0.7142857142857143, 0.7964601769911505, 0.8555133079847909, 0.8858823529411763
29, 123456789, 0.7261904761904762, 0.8421052631578947, 0.6530612244897959, 0.735632183908046, 0.7960199004975124, 0.8838672768878718
1651, 123456789, 0.7857142857142857, 0.92, 0.7666666666666667, 0.8363636363636363, 0.8846153846153846, 0.8758823529411764
7768, 123456789, 0.7619047619047619, 0.9111111111111111, 0.7192982456140351, 0.803921568627451, 0.8649789029535865, 0.8997150997150998
6584, 123456789, 0.7261904761904762, 0.8703703703703703, 0.746031746031746, 0.8034188034188035, 0.8422939068100358, 0.8728395061728396
4261, 123456789, 0.8095238095238095, 0.9019607843137255, 0.8070175438596491, 0.8518518518518519, 0.8812260536398467, 0.9067142008318478
9265, 123456789, 0.7023809523809523, 0.94, 0.6811594202898551, 0.7899159663865546, 0.8736059479553903, 0.8158823529411765
9010, 123456789, 0.7261904761904762, 0.9777777777777777, 0.6666666666666666, 0.7927927927927928, 0.8943089430894309, 0.9287749287749287
1586, 123456789, 0.7619047619047619, 0.9183673469387755, 0.7377049180327869, 0.8181818181818182, 0.8754863813229573, 0.8641399416909622
7263, 123456789, 0.7261904761904762, 0.8444444444444444, 0.7037037037037037, 0.7676767676767676, 0.811965811965812, 0.8655270655270656
4156, 123456789, 0.7380952380952381, 0.9375, 0.703125, 0.8035714285714286, 0.87890625, 0.9184027777777777
7070, 123456789, 0.7142857142857143, 0.9347826086956522, 0.671875, 0.7818181818181819, 0.8669354838709679, 0.8770022883295194
8650, 123456789, 0.7142857142857143, 0.803921568627451, 0.7454545454545455, 0.7735849056603775, 0.7915057915057916, 0.863339275103981
5088, 123456789, 0.5952380952380952, 1.0, 0.569620253164557, 0.7258064516129034, 0.8687258687258687, 0.8119658119658121
9173, 123456789, 0.7023809523809523, 0.84, 0.711864406779661, 0.7706422018348624, 0.8108108108108109, 0.858235294117647
4985, 123456789, 0.7619047619047619, 0.9565217391304348, 0.7096774193548387, 0.8148148148148149, 0.894308943089431, 0.88558352402746
99, 123456789, 0.7380952380952381, 0.9069767441860465, 0.6842105263157895, 0.78, 0.851528384279476, 0.8774815655133296
9309, 123456789, 0.7142857142857143, 0.9318181818181818, 0.6612903225806451, 0.7735849056603773, 0.861344537815126, 0.86875
9098, 123456789, 0.7738095238095238, 0.8461538461538461, 0.8, 0.8224299065420562, 0.8365019011406846, 0.9038461538461539
3474, 123456789, 0.6309523809523809, 0.82, 0.6507936507936508, 0.7256637168141592, 0.779467680608365, 0.8188235294117647
DataSeed,ModelSeed,Accuracy,Recall,Precision,F1 Score,F2 Score,AUC
2659, 123456789, 0.7619047619047619, 0.86, 0.7678571428571429, 0.8113207547169812, 0.83984375, 0.8870588235294117
5355, 123456789, 0.7738095238095238, 0.9111111111111111, 0.7321428571428571, 0.8118811881188118, 0.8686440677966102, 0.9059829059829061
6402, 123456789, 0.8095238095238095, 0.9230769230769231, 0.8, 0.8571428571428571, 0.8955223880597014, 0.907451923076923
8817, 123456789, 0.7380952380952381, 1.0, 0.6271186440677966, 0.7708333333333333, 0.8937198067632851, 0.9355951696377227
7097, 123456789, 0.7619047619047619, 0.9111111111111111, 0.7192982456140351, 0.803921568627451, 0.8649789029535865, 0.8831908831908832
3543, 123456789, 0.7976190476190477, 0.9056603773584906, 0.8, 0.8495575221238938, 0.8823529411764706, 0.8746195982958004
3630, 123456789, 0.7261904761904762, 0.8, 0.7547169811320755, 0.7766990291262137, 0.790513833992095, 0.8400000000000001
481, 123456789, 0.7619047619047619, 0.8297872340425532, 0.7647058823529411, 0.7959183673469387, 0.8158995815899581, 0.8964922369177688
8874, 123456789, 0.8214285714285714, 0.9555555555555556, 0.7678571428571429, 0.8514851485148515, 0.9110169491525425, 0.8968660968660969
2423, 123456789, 0.75, 0.9, 0.7377049180327869, 0.8108108108108109, 0.8620689655172414, 0.8847058823529411
572, 123456789, 0.7380952380952381, 0.8775510204081632, 0.7288135593220338, 0.7962962962962963, 0.8431372549019608, 0.8256559766763849
9467, 123456789, 0.7857142857142857, 0.74, 0.8809523809523809, 0.8043478260869565, 0.7644628099173554, 0.91
4669, 123456789, 0.75, 0.8703703703703703, 0.7704918032786885, 0.817391304347826, 0.848375451263538, 0.880246913580247
3840, 123456789, 0.7380952380952381, 0.9347826086956522, 0.6935483870967742, 0.7962962962962964, 0.8739837398373984, 0.8827231121281465
717, 123456789, 0.7380952380952381, 0.8636363636363636, 0.7037037037037037, 0.7755102040816326, 0.8260869565217391, 0.8522727272727272
4353, 123456789, 0.7738095238095238, 0.9782608695652174, 0.7142857142857143, 0.8256880733944955, 0.9109311740890689, 0.9399313501144165
7024, 123456789, 0.7619047619047619, 0.8043478260869565, 0.7708333333333334, 0.7872340425531915, 0.7974137931034484, 0.8775743707093822
5849, 123456789, 0.7857142857142857, 0.96, 0.75, 0.8421052631578947, 0.9090909090909091, 0.9252941176470586
7847, 123456789, 0.7380952380952381, 0.8095238095238095, 0.7083333333333334, 0.7555555555555556, 0.7870370370370371, 0.8781179138321996
6263, 123456789, 0.8095238095238095, 0.9137931034482759, 0.828125, 0.8688524590163935, 0.8952702702702704, 0.8925729442970822
107, 123456789, 0.7380952380952381, 0.9166666666666666, 0.7096774193548387, 0.7999999999999999, 0.8661417322834645, 0.8790509259259259
Run,Accuracy,Recall,Precision,F1 Score,F2 Score,AUC
8277, 0.6785714285714286, 0.8181818181818182, 0.6545454545454545, 0.7272727272727274, 0.7792207792207791, 0.8636363636363636
9552, 0.6309523809523809, 0.9142857142857143, 0.5333333333333333, 0.6736842105263158, 0.8000000000000002, 0.8478134110787172
8712, 0.6666666666666666, 0.8181818181818182, 0.6428571428571429, 0.7200000000000001, 0.7758620689655173, 0.8045454545454546
1418, 0.7142857142857143, 0.8235294117647058, 0.7368421052631579, 0.7777777777777778, 0.8045977011494252, 0.8354129530600121
3720, 0.7380952380952381, 0.82, 0.7592592592592593, 0.7884615384615384, 0.8070866141732284, 0.7582352941176471
9510, 0.8214285714285714, 0.8936170212765957, 0.8076923076923077, 0.8484848484848485, 0.875, 0.8861414606095457
2, 0.6547619047619048, 0.9166666666666666, 0.6376811594202898, 0.752136752136752, 0.842911877394636, 0.7777777777777777
209, 0.7619047619047619, 0.9069767441860465, 0.7090909090909091, 0.7959183673469388, 0.8590308370044054, 0.8740782756664776
1000, 0.7023809523809523, 0.8222222222222222, 0.6851851851851852, 0.7474747474747475, 0.7905982905982906, 0.8051282051282052
9103, 0.75, 0.8936170212765957, 0.7241379310344828, 0.7999999999999999, 0.8536585365853658, 0.8959171937895342
7355, 0.75, 0.868421052631579, 0.673469387755102, 0.7586206896551724, 0.8208955223880596, 0.8975972540045767
main.py 0 → 100644
import mlp_train as m
import plot_results as p
import analysis as a
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


def main():
    for data_seed in np.random.randint(0, 10000, size=10):
        model_seed = m.RANDOM_SEED
        train_df, test_df = m.prepare_test_train_df(data_seed)
        X_train, y_train, X_test, y_test = m.prepare_features_and_labels(train_df, test_df)
        joint_pdf_train, joint_pdf_test = a.get_joint_pdf(pd.DataFrame(X_train), pd.DataFrame(X_test))
        model, train_losses, test_losses, train_accuracies, test_accuracies = m.train_model(
            X_train, y_train, X_test, y_test, model_seed
        )
        y_test_pred, y_test_pred_binary = m.get_evaluation(model, X_test)
        results = m.sort_testdata_into_cm(test_df, y_test_pred, y_test_pred_binary)

        fig, ax = plt.subplots(2, 2, figsize=(24, 6))
        p.plot_joint_pdf(joint_pdf_train, joint_pdf_test, ax[0, 0])
        p.plot_train_test_evolution(train_accuracies, test_accuracies, ax[0, 1], metric="accuracy")
        p.plot_train_test_evolution(train_losses, test_losses, ax[1, 0])
        sns.kdeplot(data=results, x="predicted_raw", hue="CM", ax=ax[1, 1])
        # plt.savefig(f"./results/results_data{data_seed}_model{model_seed}.png")
        plt.show()


if __name__ == "__main__":
    main()
\ No newline at end of file
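One observation on the sweep above: np.random.randint draws the ten data seeds from the unseeded global RNG, so the set of seeds changes between runs. A hedged sketch of one way to make the sweep itself reproducible (illustrative; the committed main() uses the global RNG):

import numpy as np

# Draw the data seeds from a seeded generator instead of the global RNG
rng = np.random.default_rng(123456789)
for data_seed in rng.integers(0, 10000, size=10):
    ...  # same loop body as main() above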
@@ -4,16 +4,9 @@ import numpy as np
 import torch
 from tqdm import trange, tqdm
 from sklearn.model_selection import train_test_split
-import matplotlib.pyplot as plt
 from sklearn.metrics import confusion_matrix
-import seaborn as sns
 from sklearn.metrics import roc_auc_score
-# S_features_filter = [abs_S_Smin,rel_S_Smin_semi_width,rel_S_Smin_full_width,abs_S_Smax,rel_S_Smax_semi_width,rel_S_Smax_full_width,count_anomalies_S,ratio_anomalies_S,max_variation_S]
-# T_features_filter = [abs_T_Tmin,rel_T_Tmin_semi_width,rel_T_Tmin_full_width,abs_T_Tmax,rel_T_Tmax_semi_width,rel_T_Tmax_full_width,count_anomalies_T,ratio_anomalies_T,max_variation_T]
-# B_features_filter = [mean_correlation,nb_measurements]
 S_features_filter = [
     "abs_S_Smin",
     "rel_S_Smin_semi_width",
@@ -39,24 +32,28 @@ T_features_filter = [
 B_features_filter = ["mean_correlation", "nb_measurements"]

 PICKLE_PATH = "dataset_pandas/temperature.pkl"
 RANDOM_SEED = 123456789

 ##### HYPERPARAMETERS #####
-EPOCHS = 300
-BATCH_SIZE = 32
+EPOCHS = 250 #350
+BATCH_SIZE = 32 #16
 CRITERION = nn.BCELoss()
-OPTIMIZER = torch.optim.Adam
-LEARNING_RATE = 0.01
-GROWTH_RATE = 16
-DROP_RATE = 0.5
-SCHEDULER_PATIENCE = 10
-SCHEDULER_FACTOR = 0.5
+OPTIMIZER = torch.optim.Adam #torch.optim.SGD
+LEARNING_RATE = 5*1e-3 #*1e-3
+GROWTH_RATE = 32 #16
+DROP_RATE = 0.5 #0.2
+SCHEDULER_PATIENCE = 10 #15
+SCHEDULER_FACTOR = 0.1
 SCHEDULER_EPS = 1e-8
+PLOTS = []

-input_features = 11
+input_features = 11
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+##### MLP Definition #####
 class MLP(nn.Module):
     def __init__(self):
         super(MLP, self).__init__()
@@ -101,8 +98,8 @@ class MLP(nn.Module):
         y = torch.sigmoid(y)
         return y

-def prepare_data() -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+########## PREPROCESSING ############
+def prepare_test_train_df(data_seed = RANDOM_SEED) -> tuple[np.ndarray, np.ndarray]:
     # Load data
     df = pd.read_pickle(PICKLE_PATH)
@@ -110,43 +107,62 @@ def prepare_data() -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
     df = df.drop(columns=S_features_filter)

     # split the data into training and testing sets
-    train_df, test_df = train_test_split(df, test_size=0.2, random_state=123456789)
+    train_df, test_df = train_test_split(df, test_size=0.2, random_state=data_seed)
+    print(
+        f"Train alarm distribution (before undersampling): {train_df['alarm'].value_counts()}"
+    )

-    # balance the training set
+    # Balance the training set
     min_alarm = train_df["alarm"].value_counts().min()
     train_df = pd.concat(
         [
-            train_df[train_df["alarm"] == 0].sample(min_alarm),
-            train_df[train_df["alarm"] == 1].sample(min_alarm),
+            train_df[train_df["alarm"] == 0].sample(min_alarm, random_state=data_seed),
+            train_df[train_df["alarm"] == 1].sample(min_alarm, random_state=data_seed),
         ]
     )
     print(f"Train alarm distribution: {train_df['alarm'].value_counts()}")
+    return train_df, test_df

+def prepare_features_and_labels(train_df, test_df) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
     # Separate train dataset into features and labels
     X_train = train_df.drop(columns=["alarm"]).values
     y_train = train_df["alarm"].values

     # Separate test dataset into features and labels
     X_test = test_df.drop(columns=["alarm"]).values
     y_test = test_df["alarm"].values
     return X_train, y_train, X_test, y_test

+########## TRAINING ############
+def train_model(X_train, y_train, X_test, y_test, model_seed = RANDOM_SEED):
-def train_model(X_train, y_train, X_test, y_test):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     # Setting up the data loader
     train_loader = torch.utils.data.DataLoader(
-        list(zip(X_train, y_train)), batch_size=BATCH_SIZE, shuffle=True
+        list(zip(X_train, y_train)),
+        batch_size=BATCH_SIZE,
+        shuffle=True
     )
+    # Setting up the test loader
+    test_loader = torch.utils.data.DataLoader(
+        list(zip(X_test, y_test)),
+        batch_size=BATCH_SIZE,
+        shuffle=False
+    )

+    # Fix the model seed for reproducibility
+    torch.manual_seed(model_seed)

     # Define model
     model = MLP().to(device)
+    # Set Dropout rate
+    model.dropout = nn.Dropout(DROP_RATE)

     # Define loss function
@@ -167,25 +183,33 @@ def train_model(X_train, y_train, X_test, y_test):
     # Train model
     train_losses = []
     test_losses = []
+    train_accuracies = []
+    test_accuracies = []
     for epoch in range(EPOCHS):
         epoch_train_loss = 0
+        total_train = 0
+        correct_train = 0
         with tqdm(train_loader, unit="batch") as t:
             for data, target in t:
                 t.set_description(f"Epoch {str(epoch).rjust(5)}")

-                # Move data to device
+                # Move data to device and set model to train mode
                 data, target = data.to(device), target.to(device)
+                model.train()

-                # Zero the gradients
+                # Zero the gradients and forward pass
                 optimizer.zero_grad()
                 output = model(data.float())

                 # Calculate loss
                 loss = criterion(output, target.float().view(-1, 1))
                 epoch_train_loss += loss.item()

+                # Calculate accuracy
+                y_train_pred_binary = np.where(output.data > 0.5, 1, 0).reshape(1, len(output))[0]
+                correct_train += (target.numpy() == y_train_pred_binary).sum().item()
+                total_train += len(target)

                 # Backpropagation
                 loss.backward()
@@ -195,13 +219,24 @@ def train_model(X_train, y_train, X_test, y_test):
                 # Display loss
                 t.set_postfix(train_loss=f"{loss.item():.4f}")

-            scheduler.step(loss)
-            # print optimizer learning rate
-            print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
+        # Compute total train accuracy
+        train_accuracy = 100 * correct_train / total_train
+        train_accuracies.append(train_accuracy)
+        print(f"Train accuracy: {train_accuracy}")

+        # print optimizer learning rate
+        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")

-        # compute train loss
-        epoch_train_loss /= len(train_loader)
+        # compute train loss
+        epoch_train_loss /= len(train_loader)
+        # update scheduler
+        scheduler.step(epoch_train_loss)

+        # set model to evaluation mode
+        model.eval()
+        with torch.no_grad():
             # Evaluate model on test set
             y_test_pred = (
                 model(torch.tensor(X_test).float().to(device)).cpu().detach().numpy()
@@ -211,10 +246,18 @@ def train_model(X_train, y_train, X_test, y_test):
                 torch.tensor(y_test).float().view(-1, 1),
             )

-        train_losses.append(epoch_train_loss)
-        print(f"Train loss: {epoch_train_loss:.4f}")
-        print(f"Test loss: {test_loss.item():.4f}")
+        y_test_pred_binary = np.where(y_test_pred > 0.5, 1, 0).reshape(1, len(y_test_pred))
+        correct_test = (y_test_pred_binary == y_test).sum().item()
+        test_accuracy = 100 * correct_test / len(y_test)
+        test_accuracies.append(test_accuracy)
+        print(f"Test accuracy: {test_accuracy}")

         test_losses.append(test_loss.item())
+        train_losses.append(epoch_train_loss)
+        print(f"Test loss: {test_loss.item():.4f}")

         # save model if test loss has decreased
         if len(test_losses) == 1 or test_loss < min(test_losses[:-1]):
@@ -222,69 +265,81 @@ def train_model(X_train, y_train, X_test, y_test):
                 model.state_dict(),
                 f"checkpoints/mlp_{epoch}.pth",
             )
+    return model, train_losses, test_losses, train_accuracies, test_accuracies

-    # Plot losses
-    sns.lineplot(x=range(len(train_losses)), y=train_losses, label="Train loss")
-    sns.lineplot(x=range(len(test_losses)), y=test_losses, label="Test loss")
-    plt.xlabel("Epoch")
-    plt.ylabel("Loss")
-    plt.legend()
-    plt.show()

+########## POST TRAINING & EVALUATION ############
-    # load best model from checkpoint
+def load_model_from_checkpoint(model, test_losses):
+    # load best model from checkpoint
+    print(f"Loading model from checkpoint: mlp_{np.argmin(test_losses)}.pth")
     model.load_state_dict(torch.load(f"checkpoints/mlp_{np.argmin(test_losses)}.pth"))
+    return model

+def get_evaluation(model, X_test):
     # predict on test set
-    y_test_pred = model(torch.tensor(X_test).float().to(device)).cpu().detach().numpy()
+    model.eval()
+    with torch.no_grad():
+        y_test_pred = (
+            model(torch.tensor(X_test).float().to(device)).cpu().detach().numpy()
+        )
     y_test_pred_binary = np.where(y_test_pred > 0.5, 1, 0)
+    # print parameter count of model
+    # print(f"Parameter count: {sum(p.numel() for p in model.parameters())}")
+    return y_test_pred, y_test_pred_binary

+def get_confusion_matrix(y_test_pred_binary, y_test):
     # calculate confusion matrix
     cm = confusion_matrix(y_test, y_test_pred_binary)
+    return cm
-    print(cm)

+def get_metrics(cm):
     # calculate accuracy
     accuracy = np.sum(np.diag(cm)) / np.sum(cm)
     print(f"Accuracy: {accuracy}")
-    # print recall
+    # calculate recall
     recall = cm[1, 1] / (cm[1, 0] + cm[1, 1])
     print(f"Recall: {recall}")
-    # print precision
+    # calculate precision
     precision = cm[1, 1] / (cm[0, 1] + cm[1, 1])
     print(f"Precision: {precision}")
-    # print F1 score
+    # calculate F1 score
     f1 = 2 * (precision * recall) / (precision + recall)
     print(f"F1 score: {f1}")
-    # print F2 score
+    # calculate F2 score
     f2 = 5 * (precision * recall) / (4 * precision + recall)
     print(f"F2 score: {f2}")
-    # print AUC
-    auc = roc_auc_score(y_test, y_test_pred)
-    print(f"AUC: {auc}")
-    # plot confusion matrix using seaborn
-    sns.heatmap(cm, annot=True, fmt="d")
-    plt.xlabel("Predicted")
-    plt.ylabel("True")
-    plt.show()
-    return model

-def main():
-    X_train, y_train, X_test, y_test = prepare_data()
-    model = train_model(X_train, y_train, X_test, y_test)
-    # print parameter count of model
-    print(f"Parameter count: {sum(p.numel() for p in model.parameters())}")

-if __name__ == "__main__":
-    main()
+    return accuracy, precision, recall, f1, f2

+def sort_testdata_into_cm(test_df, y_test_pred, y_test_pred_binary):
+    # Function to determine TP, TN, FP, FN
+    def determine_result(row):
+        if row['alarm'] == 1 and row['predicted'] == 1:
+            return 'TP'
+        elif row['alarm'] == 0 and row['predicted'] == 0:
+            return 'TN'
+        elif row['alarm'] == 0 and row['predicted'] == 1:
+            return 'FP'
+        elif row['alarm'] == 1 and row['predicted'] == 0:
+            return 'FN'

+    # Sort the predictions into CM categories
+    y_pred_series = pd.Series(y_test_pred.reshape(1, -1)[0])
+    y_pred_binary_series = pd.Series(y_test_pred_binary.reshape(1, -1)[0])
+    results = test_df.copy()
+    y_pred_series.index = results.index
+    y_pred_binary_series.index = results.index
+    results['predicted'] = y_pred_binary_series.astype(bool)
+    results['predicted_raw'] = y_pred_series
+    # Create new column based on conditions for the confusion matrix
+    results['CM'] = results.apply(determine_result, axis=1)
+    return results
\ No newline at end of file
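For orientation, a small usage sketch of the new get_metrics. The confusion matrix below is made up for illustration; rows are true labels and columns are predictions, in the layout returned by sklearn's confusion_matrix:

import numpy as np
from mlp_train import get_metrics

cm = np.array([[20, 17],   # TN, FP
               [ 2, 45]])  # FN, TP
accuracy, precision, recall, f1, f2 = get_metrics(cm)
# accuracy = 65/84 ≈ 0.774, recall = 45/47 ≈ 0.957, precision = 45/62 ≈ 0.726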
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler


######### PLOTTING FUNCTIONS ########
def plot_train_test_evolution(train_losses, test_losses, ax, metric='loss'):
    # Plot train/test curves for the given metric
    sns.lineplot(x=range(len(train_losses)), y=train_losses, label=f"Train {metric}", ax=ax)
    sns.lineplot(x=range(len(test_losses)), y=test_losses, label=f"Test {metric}", ax=ax)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(metric.capitalize())


def plot_cm(cm, ax):
    # plot confusion matrix
    sns.heatmap(cm, annot=True, fmt="d", ax=ax)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")


def plot_joint_pdf(joint_pdf_train, joint_pdf_test, ax):
    # Plot the distribution of the joint PDF values for train and test data
    sns.kdeplot(data=joint_pdf_train, label='Train data', ax=ax, common_norm=True)
    sns.kdeplot(data=joint_pdf_test, label='Test data', ax=ax, common_norm=True)
    ax.set_xlabel('X [11 dimensions]')
    ax.set_ylabel('Density')
    ax.set_title('Joint PDF of the features')
    ax.legend()


# This plots the features separately (not sure if we still need this)
def plot_features(df, error_type, ax):
    scaler = StandardScaler()
    df_normalized = pd.DataFrame(scaler.fit_transform(df), columns=df.columns)
    sns.boxplot(data=df_normalized, ax=ax)
    ax.set_title(error_type)
    ax.tick_params(axis="x", rotation=90)


###################### LOGGING FUNCTIONS
# This adds a line to the csv file for the logs
def save_metrics(filename, data_seed, model_seed, accuracy, recall, precision, f1, f2, auc):
    # The AUC column is part of the CSV header above, so it is written here as well
    with open(f"./logs/{filename}.csv", "a") as f:
        f.write(f'{data_seed}, {model_seed}, {accuracy}, {recall}, {precision}, {f1}, {f2}, {auc}\n')


############# COMPARE OUR ACCURACIES WITH ROMARIC
def plot_comparison_results():
    FILE_PATHS = ['./logs/random.csv', './logs/fix-data.csv', './logs/fix-model.csv', './logs/fix-data-param-romaric.csv', './logs/fix-model-param-romaric.csv']
    columns = ['Accuracy']  # ['Accuracy', 'Precision', 'Recall']
    df0 = pd.read_csv(FILE_PATHS[0], header=0)[columns]
    df1 = pd.read_csv(FILE_PATHS[1], header=0)[columns]
    df2 = pd.read_csv(FILE_PATHS[2], header=0)[columns]
    df3 = pd.read_csv(FILE_PATHS[3], header=0)[columns]
    df4 = pd.read_csv(FILE_PATHS[4], header=0)[columns]
    fig, axes = plt.subplots(2, 3, figsize=(12, 10), sharey=True)
    sns.boxplot(data=df0, ax=axes[0, 0])
    sns.boxplot(data=df1, ax=axes[0, 1])
    sns.boxplot(data=df2, ax=axes[0, 2])
    sns.boxplot(data=df3, ax=axes[1, 1])
    sns.boxplot(data=df4, ax=axes[1, 2])
    axes[0, 0].set_title('Random seeds')
    axes[0, 1].set_title('Fix data, random model seeds')
    axes[0, 2].set_title('Fix model, random data seeds')
    axes[1, 1].set_title('Fix data (Romaric)')
    axes[1, 2].set_title('Fix model (Romaric)')
    plt.tight_layout()
    plt.savefig('error_distribution.pdf')
    plt.show()
\ No newline at end of file
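A hedged example of appending one row with save_metrics, matching the eight-column layout of the logs above. It assumes a ./logs directory exists and uses the auc argument added to the signature here; the values are rounded copies of the first logged row, purely for illustration:

save_metrics(
    "fix-data",  # writes to ./logs/fix-data.csv, one of the FILE_PATHS above
    data_seed=123456789, model_seed=821,
    accuracy=0.7738, recall=0.9565, precision=0.7213,
    f1=0.8224, f2=0.8980, auc=0.9411,
)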