Skip to content
Snippets Groups Projects
Commit f70f4369 authored by Boshra Ariguib's avatar Boshra Ariguib
Browse files

added plots about results

parent 9daba7f5
No related branches found
No related tags found
No related merge requests found
Showing
with 278 additions and 75 deletions
File added
Run, Accuracy, Recall, Precision, F1 Score, F2 Score, AUC
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.919908466819222
0, 0.7619047619047619, 0.8695652173913043, 0.7407407407407407, 0.7999999999999999, 0.8403361344537816, 0.9176201372997711
0, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.9181922196796339
0, 0.7976190476190477, 0.8913043478260869, 0.7735849056603774, 0.8282828282828283, 0.8649789029535865, 0.9256292906178489
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.927345537757437
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.914187643020595
0, 0.8690476190476191, 0.8695652173913043, 0.8888888888888888, 0.8791208791208792, 0.8733624454148471, 0.9267734553775744
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9067505720823799
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9319221967963387
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9221967963386728
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.919908466819222
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9187643020594966
0, 0.7619047619047619, 0.8913043478260869, 0.7321428571428571, 0.8039215686274508, 0.8541666666666666, 0.9136155606407323
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9273455377574371
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9153318077803204
0, 0.7976190476190477, 0.8913043478260869, 0.7735849056603774, 0.8282828282828283, 0.8649789029535865, 0.9176201372997712
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9204805491990848
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9227688787185354
\ No newline at end of file
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.891304347826087
0, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.8775743707093822
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9250572082379862
0, 0.7976190476190477, 0.9130434782608695, 0.7636363636363637, 0.8316831683168316, 0.8786610878661087, 0.9250572082379863
0, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.8758581235697941
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.8981693363844393
0, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9330663615560641
0, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.9244851258581237
0, 0.7738095238095238, 0.5869565217391305, 1.0, 0.7397260273972603, 0.6398104265402843, 0.9067505720823799
0, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9382151029748284
Run, Accuracy, Recall, Precision, F1 Score, F2 Score, AUC
0, 0.7619047619047619, 0.9090909090909091, 0.7142857142857143, 0.8, 0.8620689655172413, 0.91875
0, 0.7738095238095238, 0.9024390243902439, 0.7115384615384616, 0.7956989247311829, 0.8564814814814816, 0.9160521837776517
0, 0.75, 0.9411764705882353, 0.7272727272727273, 0.8205128205128205, 0.8888888888888888, 0.803921568627451
0, 0.7738095238095238, 0.9148936170212766, 0.7413793103448276, 0.819047619047619, 0.8739837398373983, 0.885566417481311
0, 0.7738095238095238, 0.8076923076923077, 0.8235294117647058, 0.8155339805825242, 0.8108108108108109, 0.8311298076923077
0, 0.7023809523809523, 0.8478260869565217, 0.6842105263157895, 0.7572815533980581, 0.8091286307053941, 0.8707093821510298
0, 0.75, 0.9347826086956522, 0.7049180327868853, 0.8037383177570093, 0.8775510204081632, 0.9250572082379863
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.908466819221968
0, 0.75, 0.9024390243902439, 0.6851851851851852, 0.7789473684210526, 0.8486238532110092, 0.8865570051049347
0, 0.8214285714285714, 0.9347826086956522, 0.7818181818181819, 0.8514851485148516, 0.8995815899581592, 0.9164759725400459
0, 0.7380952380952381, 0.88, 0.7333333333333333, 0.8, 0.8461538461538461, 0.8923529411764706
0, 0.75, 0.8775510204081632, 0.7413793103448276, 0.8037383177570093, 0.8464566929133858, 0.8967930029154519
0, 0.6547619047619048, 0.7906976744186046, 0.6296296296296297, 0.7010309278350516, 0.7522123893805309, 0.8077141236528644
0, 0.8095238095238095, 0.8823529411764706, 0.8181818181818182, 0.8490566037735848, 0.8687258687258687, 0.8835412953060012
0, 0.7738095238095238, 0.8775510204081632, 0.7678571428571429, 0.819047619047619, 0.8531746031746031, 0.8594752186588921
0, 0.7261904761904762, 0.84, 0.7368421052631579, 0.7850467289719626, 0.8171206225680934, 0.8664705882352941
0, 0.8214285714285714, 0.9230769230769231, 0.8135593220338984, 0.8648648648648649, 0.8988764044943821, 0.9122596153846154
0, 0.6547619047619048, 0.7272727272727273, 0.6530612244897959, 0.6881720430107526, 0.7111111111111111, 0.7704545454545455
DataSeed,ModelSeed,Accuracy,Recall,Precision,F1 Score,F2 Score,AUC
123456789, 821, 0.7738095238095238, 0.9565217391304348, 0.7213114754098361, 0.822429906542056, 0.8979591836734693, 0.9410755148741418
123456789, 6712, 0.7619047619047619, 0.9347826086956522, 0.7166666666666667, 0.8113207547169811, 0.8811475409836066, 0.9445080091533181
123456789, 8255, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.931350114416476
123456789, 4502, 0.7619047619047619, 0.9565217391304348, 0.7096774193548387, 0.8148148148148149, 0.894308943089431, 0.908466819221968
123456789, 3403, 0.7857142857142857, 0.9347826086956522, 0.7413793103448276, 0.826923076923077, 0.8884297520661157, 0.9199084668192219
123456789, 3008, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9296338672768879
123456789, 123, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.9336384439359268
123456789, 6564, 0.7738095238095238, 0.9565217391304348, 0.7213114754098361, 0.822429906542056, 0.8979591836734693, 0.9284897025171625
123456789, 4610, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9290617848970252
123456789, 3914, 0.7738095238095238, 0.9565217391304348, 0.7213114754098361, 0.822429906542056, 0.8979591836734693, 0.9364988558352403
123456789, 2658, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9256292906178489
123456789, 8773, 0.7738095238095238, 0.9347826086956522, 0.7288135593220338, 0.819047619047619, 0.8847736625514403, 0.937070938215103
123456789, 4258, 0.7619047619047619, 0.9130434782608695, 0.7241379310344828, 0.8076923076923076, 0.8677685950413221, 0.9221967963386727
123456789, 1430, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.933066361556064
123456789, 2979, 0.7857142857142857, 0.9347826086956522, 0.7413793103448276, 0.826923076923077, 0.8884297520661157, 0.9307780320366132
123456789, 3935, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9233409610983982
123456789, 6961, 0.7619047619047619, 0.9347826086956522, 0.7166666666666667, 0.8113207547169811, 0.8811475409836066, 0.9290617848970252
123456789, 2381, 0.7738095238095238, 0.9130434782608695, 0.7368421052631579, 0.8155339805825242, 0.871369294605809, 0.9296338672768879
123456789, 4368, 0.7619047619047619, 0.9347826086956522, 0.7166666666666667, 0.8113207547169811, 0.8811475409836066, 0.9439359267734554
123456789, 3816, 0.7857142857142857, 0.9347826086956522, 0.7413793103448276, 0.826923076923077, 0.8884297520661157, 0.9336384439359268
DataSeed,ModelSeed,Accuracy,Recall,Precision,F1 Score,F2 Score,AUC
123456789, 4224, 0.7738095238095238, 0.9347826086956522, 0.7288135593220338, 0.819047619047619, 0.8847736625514403, 0.9290617848970252
123456789, 7729, 0.7857142857142857, 0.9782608695652174, 0.7258064516129032, 0.8333333333333333, 0.9146341463414634, 0.9170480549199085
123456789, 197, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9216247139588101
123456789, 8678, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.927345537757437
123456789, 1418, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9193363844393593
123456789, 1742, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.9170480549199085
123456789, 2456, 0.7857142857142857, 0.8913043478260869, 0.7592592592592593, 0.8200000000000001, 0.8613445378151261, 0.9302059496567505
123456789, 5305, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.8844393592677346
123456789, 6591, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.9262013729977117
123456789, 4043, 0.75, 0.9130434782608695, 0.711864406779661, 0.7999999999999999, 0.8641975308641975, 0.9090389016018308
123456789, 6737, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.9302059496567506
123456789, 3856, 0.7976190476190477, 0.8913043478260869, 0.7735849056603774, 0.8282828282828283, 0.8649789029535865, 0.925629290617849
123456789, 4955, 0.7380952380952381, 0.9130434782608695, 0.7, 0.7924528301886793, 0.8606557377049181, 0.9193363844393593
123456789, 3355, 0.7738095238095238, 0.9347826086956522, 0.7288135593220338, 0.819047619047619, 0.8847736625514403, 0.9124713958810069
123456789, 9015, 0.7142857142857143, 0.9565217391304348, 0.6666666666666666, 0.7857142857142856, 0.88, 0.9204805491990846
123456789, 4165, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.9347826086956522
123456789, 4498, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.9244851258581236
123456789, 7503, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9279176201372998
123456789, 7120, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.8781464530892449
123456789, 5406, 0.7976190476190477, 0.9130434782608695, 0.7636363636363637, 0.8316831683168316, 0.8786610878661087, 0.9290617848970252
123456789, 1173, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.9147597254004577
123456789, 810, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.9290617848970252
123456789, 9865, 0.7857142857142857, 0.9347826086956522, 0.7413793103448276, 0.826923076923077, 0.8884297520661157, 0.8970251716247141
123456789, 5329, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.931350114416476
DataSeed,ModelSeed,Accuracy,Recall,Precision,F1 Score,F2 Score,AUC
5517, 123456789, 0.7261904761904762, 0.9, 0.7142857142857143, 0.7964601769911505, 0.8555133079847909, 0.8858823529411763
29, 123456789, 0.7261904761904762, 0.8421052631578947, 0.6530612244897959, 0.735632183908046, 0.7960199004975124, 0.8838672768878718
1651, 123456789, 0.7857142857142857, 0.92, 0.7666666666666667, 0.8363636363636363, 0.8846153846153846, 0.8758823529411764
7768, 123456789, 0.7619047619047619, 0.9111111111111111, 0.7192982456140351, 0.803921568627451, 0.8649789029535865, 0.8997150997150998
6584, 123456789, 0.7261904761904762, 0.8703703703703703, 0.746031746031746, 0.8034188034188035, 0.8422939068100358, 0.8728395061728396
4261, 123456789, 0.8095238095238095, 0.9019607843137255, 0.8070175438596491, 0.8518518518518519, 0.8812260536398467, 0.9067142008318478
9265, 123456789, 0.7023809523809523, 0.94, 0.6811594202898551, 0.7899159663865546, 0.8736059479553903, 0.8158823529411765
9010, 123456789, 0.7261904761904762, 0.9777777777777777, 0.6666666666666666, 0.7927927927927928, 0.8943089430894309, 0.9287749287749287
1586, 123456789, 0.7619047619047619, 0.9183673469387755, 0.7377049180327869, 0.8181818181818182, 0.8754863813229573, 0.8641399416909622
7263, 123456789, 0.7261904761904762, 0.8444444444444444, 0.7037037037037037, 0.7676767676767676, 0.811965811965812, 0.8655270655270656
4156, 123456789, 0.7380952380952381, 0.9375, 0.703125, 0.8035714285714286, 0.87890625, 0.9184027777777777
7070, 123456789, 0.7142857142857143, 0.9347826086956522, 0.671875, 0.7818181818181819, 0.8669354838709679, 0.8770022883295194
8650, 123456789, 0.7142857142857143, 0.803921568627451, 0.7454545454545455, 0.7735849056603775, 0.7915057915057916, 0.863339275103981
5088, 123456789, 0.5952380952380952, 1.0, 0.569620253164557, 0.7258064516129034, 0.8687258687258687, 0.8119658119658121
9173, 123456789, 0.7023809523809523, 0.84, 0.711864406779661, 0.7706422018348624, 0.8108108108108109, 0.858235294117647
4985, 123456789, 0.7619047619047619, 0.9565217391304348, 0.7096774193548387, 0.8148148148148149, 0.894308943089431, 0.88558352402746
99, 123456789, 0.7380952380952381, 0.9069767441860465, 0.6842105263157895, 0.78, 0.851528384279476, 0.8774815655133296
9309, 123456789, 0.7142857142857143, 0.9318181818181818, 0.6612903225806451, 0.7735849056603773, 0.861344537815126, 0.86875
9098, 123456789, 0.7738095238095238, 0.8461538461538461, 0.8, 0.8224299065420562, 0.8365019011406846, 0.9038461538461539
3474, 123456789, 0.6309523809523809, 0.82, 0.6507936507936508, 0.7256637168141592, 0.779467680608365, 0.8188235294117647
DataSeed,ModelSeed,Accuracy,Recall,Precision,F1 Score,F2 Score,AUC
2659, 123456789, 0.7619047619047619, 0.86, 0.7678571428571429, 0.8113207547169812, 0.83984375, 0.8870588235294117
5355, 123456789, 0.7738095238095238, 0.9111111111111111, 0.7321428571428571, 0.8118811881188118, 0.8686440677966102, 0.9059829059829061
6402, 123456789, 0.8095238095238095, 0.9230769230769231, 0.8, 0.8571428571428571, 0.8955223880597014, 0.907451923076923
8817, 123456789, 0.7380952380952381, 1.0, 0.6271186440677966, 0.7708333333333333, 0.8937198067632851, 0.9355951696377227
7097, 123456789, 0.7619047619047619, 0.9111111111111111, 0.7192982456140351, 0.803921568627451, 0.8649789029535865, 0.8831908831908832
3543, 123456789, 0.7976190476190477, 0.9056603773584906, 0.8, 0.8495575221238938, 0.8823529411764706, 0.8746195982958004
3630, 123456789, 0.7261904761904762, 0.8, 0.7547169811320755, 0.7766990291262137, 0.790513833992095, 0.8400000000000001
481, 123456789, 0.7619047619047619, 0.8297872340425532, 0.7647058823529411, 0.7959183673469387, 0.8158995815899581, 0.8964922369177688
8874, 123456789, 0.8214285714285714, 0.9555555555555556, 0.7678571428571429, 0.8514851485148515, 0.9110169491525425, 0.8968660968660969
2423, 123456789, 0.75, 0.9, 0.7377049180327869, 0.8108108108108109, 0.8620689655172414, 0.8847058823529411
572, 123456789, 0.7380952380952381, 0.8775510204081632, 0.7288135593220338, 0.7962962962962963, 0.8431372549019608, 0.8256559766763849
9467, 123456789, 0.7857142857142857, 0.74, 0.8809523809523809, 0.8043478260869565, 0.7644628099173554, 0.91
4669, 123456789, 0.75, 0.8703703703703703, 0.7704918032786885, 0.817391304347826, 0.848375451263538, 0.880246913580247
3840, 123456789, 0.7380952380952381, 0.9347826086956522, 0.6935483870967742, 0.7962962962962964, 0.8739837398373984, 0.8827231121281465
717, 123456789, 0.7380952380952381, 0.8636363636363636, 0.7037037037037037, 0.7755102040816326, 0.8260869565217391, 0.8522727272727272
4353, 123456789, 0.7738095238095238, 0.9782608695652174, 0.7142857142857143, 0.8256880733944955, 0.9109311740890689, 0.9399313501144165
7024, 123456789, 0.7619047619047619, 0.8043478260869565, 0.7708333333333334, 0.7872340425531915, 0.7974137931034484, 0.8775743707093822
5849, 123456789, 0.7857142857142857, 0.96, 0.75, 0.8421052631578947, 0.9090909090909091, 0.9252941176470586
7847, 123456789, 0.7380952380952381, 0.8095238095238095, 0.7083333333333334, 0.7555555555555556, 0.7870370370370371, 0.8781179138321996
6263, 123456789, 0.8095238095238095, 0.9137931034482759, 0.828125, 0.8688524590163935, 0.8952702702702704, 0.8925729442970822
107, 123456789, 0.7380952380952381, 0.9166666666666666, 0.7096774193548387, 0.7999999999999999, 0.8661417322834645, 0.8790509259259259
Run,Accuracy,Recall,Precision,F1 Score,F2 Score,AUC
8277, 0.6785714285714286, 0.8181818181818182, 0.6545454545454545, 0.7272727272727274, 0.7792207792207791, 0.8636363636363636
9552, 0.6309523809523809, 0.9142857142857143, 0.5333333333333333, 0.6736842105263158, 0.8000000000000002, 0.8478134110787172
8712, 0.6666666666666666, 0.8181818181818182, 0.6428571428571429, 0.7200000000000001, 0.7758620689655173, 0.8045454545454546
1418, 0.7142857142857143, 0.8235294117647058, 0.7368421052631579, 0.7777777777777778, 0.8045977011494252, 0.8354129530600121
3720, 0.7380952380952381, 0.82, 0.7592592592592593, 0.7884615384615384, 0.8070866141732284, 0.7582352941176471
9510, 0.8214285714285714, 0.8936170212765957, 0.8076923076923077, 0.8484848484848485, 0.875, 0.8861414606095457
2, 0.6547619047619048, 0.9166666666666666, 0.6376811594202898, 0.752136752136752, 0.842911877394636, 0.7777777777777777
209, 0.7619047619047619, 0.9069767441860465, 0.7090909090909091, 0.7959183673469388, 0.8590308370044054, 0.8740782756664776
1000, 0.7023809523809523, 0.8222222222222222, 0.6851851851851852, 0.7474747474747475, 0.7905982905982906, 0.8051282051282052
9103, 0.75, 0.8936170212765957, 0.7241379310344828, 0.7999999999999999, 0.8536585365853658, 0.8959171937895342
7355, 0.75, 0.868421052631579, 0.673469387755102, 0.7586206896551724, 0.8208955223880596, 0.8975972540045767
......@@ -8,6 +8,9 @@ import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler
from pdf2image import convert_from_path
import random
# S_features_filter = [abs_S_Smin,rel_S_Smin_semi_width,rel_S_Smin_full_width,abs_S_Smax,rel_S_Smax_semi_width,rel_S_Smax_full_width,count_anomalies_S,ratio_anomalies_S,max_variation_S]
......@@ -43,17 +46,19 @@ RANDOM_SEED = 123456789
##### HYPERPARAMETERS #####
EPOCHS = 200 #350
BATCH_SIZE = 16 #32
EPOCHS = 250 #350
BATCH_SIZE = 32 #16
CRITERION = nn.BCELoss()
OPTIMIZER = torch.optim.Adam #torch.optim.SGD
LEARNING_RATE = 1e-3 #5*1e-3
GROWTH_RATE = 16 #32
DROP_RATE = 0.2 #0.5
SCHEDULER_PATIENCE = 15 #10
LEARNING_RATE = 5*1e-3 #*1e-3
GROWTH_RATE = 32 #16
DROP_RATE = 0.5 #0.2
SCHEDULER_PATIENCE = 10 #15
SCHEDULER_FACTOR = 0.1
SCHEDULER_EPS = 1e-8
PLOTS = []
input_features = 11
......@@ -105,7 +110,7 @@ class MLP(nn.Module):
return y
def prepare_data(seed = RANDOM_SEED) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
def prepare_test_train_df(data_seed = RANDOM_SEED) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
# Load data
df = pd.read_pickle(PICKLE_PATH)
......@@ -113,7 +118,7 @@ def prepare_data(seed = RANDOM_SEED) -> tuple[np.ndarray, np.ndarray, np.ndarray
df = df.drop(columns=S_features_filter)
# split the into training and testing sets
train_df, test_df = train_test_split(df, test_size=0.2, random_state=seed)
train_df, test_df = train_test_split(df, test_size=0.2, random_state=data_seed)
print(
f"Train alarm distribution (befor undersampling): {train_df['alarm'].value_counts()}"
......@@ -123,11 +128,17 @@ def prepare_data(seed = RANDOM_SEED) -> tuple[np.ndarray, np.ndarray, np.ndarray
min_alarm = train_df["alarm"].value_counts().min()
train_df = pd.concat(
[
train_df[train_df["alarm"] == 0].sample(min_alarm, random_state=seed),
train_df[train_df["alarm"] == 1].sample(min_alarm, random_state=seed),
train_df[train_df["alarm"] == 0].sample(min_alarm, random_state=data_seed),
train_df[train_df["alarm"] == 1].sample(min_alarm, random_state=data_seed),
]
)
print(f"Train alarm distribution: {train_df['alarm'].value_counts()}")
return train_df, test_df
def prepare_data(train_df = None, test_df = None, data_seed = RANDOM_SEED, model_seed = RANDOM_SEED):
if train_df is None and test_df is None:
train_df, test_df = prepare_test_train_df(seed = RANDOM_SEED)
X_train = train_df.drop(columns=["alarm"]).values
y_train = train_df["alarm"].values
......@@ -138,9 +149,10 @@ def prepare_data(seed = RANDOM_SEED) -> tuple[np.ndarray, np.ndarray, np.ndarray
return X_train, y_train, X_test, y_test
def train_model(X_train, y_train, X_test, y_test, seed = RANDOM_SEED):
torch.manual_seed(seed)
def train_model(X_train, y_train, X_test, y_test, data_seed = RANDOM_SEED, model_seed = RANDOM_SEED):
torch.manual_seed(model_seed)
# Setting up the data loader
train_loader = torch.utils.data.DataLoader(
......@@ -156,7 +168,7 @@ def train_model(X_train, y_train, X_test, y_test, seed = RANDOM_SEED):
shuffle=False
)
torch.manual_seed(seed)
torch.manual_seed(model_seed)
# Define model
model = MLP().to(device)
......@@ -249,7 +261,9 @@ def train_model(X_train, y_train, X_test, y_test, seed = RANDOM_SEED):
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()
plt.tight_layout()
#plt.savefig(f'./plots/plot_loss_data{data_seed}_model{model_seed}.pdf')
plt.clf()
# load best model from checkpoint
# print(f"Loading model from checkpoint: mlp_{np.argmin(test_losses)}.pth")
......@@ -257,7 +271,7 @@ def train_model(X_train, y_train, X_test, y_test, seed = RANDOM_SEED):
return model
def evaluate_model(model, X_test, y_test):
def evaluate_model(model, X_test, y_test, data_seed = RANDOM_SEED, model_seed=RANDOM_SEED):
model.eval()
with torch.no_grad():
......@@ -301,18 +315,122 @@ def evaluate_model(model, X_test, y_test):
sns.heatmap(cm, annot=True, fmt="d")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()
plt.tight_layout()
#plt.savefig(f'./plots/plot_cm_data{data_seed}_model{model_seed}.pdf')
plt.clf()
with open("./logs/fix-data-param-romaric.csv", "a") as f:
f.write(f'{data_seed}, {model_seed}, {accuracy}, {recall}, {precision}, {f1}, {f2}, {auc}\n')
return y_test_pred_binary
# Function to determine TP, TN, FP, FN
def determine_result(row):
    """Label one row with its confusion-matrix cell.

    ``row`` is anything indexable by the keys ``'alarm'`` (ground truth)
    and ``'predicted'`` (model output), e.g. a DataFrame row or a dict.

    Returns ``'TP'``, ``'TN'``, ``'FP'`` or ``'FN'``; returns ``None`` when
    either value is neither 0 nor 1.
    """
    actual = row['alarm']
    if actual == 1:
        # Positive ground truth: hit or miss.
        guess = row['predicted']
        if guess == 1:
            return 'TP'
        if guess == 0:
            return 'FN'
    elif actual == 0:
        # Negative ground truth: correct rejection or false alarm.
        guess = row['predicted']
        if guess == 0:
            return 'TN'
        if guess == 1:
            return 'FP'
    # Any other combination has no confusion-matrix label.
    return None
def plot_features(df, error_type, data_seed, model_seed):
    """Draw a box plot of the standardized columns of ``df`` and save it as PDF.

    Every column is passed through a freshly fitted ``StandardScaler``
    before plotting. ``error_type`` is used both as the plot title and as
    part of the output file name, together with ``data_seed`` and
    ``model_seed``. The current figure is cleared after saving.

    Returns the object returned by ``sns.boxplot``.
    """
    standardized_values = StandardScaler().fit_transform(df)
    df_normalized = pd.DataFrame(standardized_values, columns=df.columns)

    boxplot_axes = sns.boxplot(data=df_normalized)
    plt.title(error_type)
    # Vertical tick labels so the feature names do not overlap.
    plt.xticks(rotation=90)
    plt.tight_layout()

    plt.savefig(f'./plots/plot_{error_type}_data{data_seed}_model{model_seed}.pdf')
    # Reset the shared pyplot figure so the next plot starts empty.
    plt.clf()
    return boxplot_axes
with open("log_fix_init.csv", "a") as f:
f.write(f'0, {accuracy}, {recall}, {precision}, {f1}, {f2}, {auc}\n')
# Entry point: trains and evaluates the model once per seed.
# NOTE(review): this span is taken from a diff view with indentation stripped,
# so two loop variants appear back to back. The first loop (over `seed`) looks
# like the removed version and the second (over `model_seed`) like the added
# one -- confirm against the real file before relying on either.
def main():
# Variant 1: seeds are 5**i + 3 for i in 0..9; data is prepared with default
# arguments and `seed` is passed as the fifth positional argument of train_model.
for seed in [5 ** i + 3 for i in range(10)]:
X_train, y_train, X_test, y_test = prepare_data()
model = train_model(X_train, y_train, X_test, y_test, seed)
evaluate_model(model, X_test, y_test)
# Variant 2: 20 model seeds drawn from [0, 10000) with numpy's global RNG
# (not seeded here, so the drawn seeds differ between runs); the data seed
# stays fixed at RANDOM_SEED so only the model seed varies.
for model_seed in np.random.randint(0, 10000, size=20):
data_seed = RANDOM_SEED
train_df, test_df = prepare_test_train_df(data_seed)
X_train, y_train, X_test, y_test = prepare_data(train_df, test_df, data_seed, model_seed)
model = train_model(X_train, y_train, X_test, y_test, data_seed, model_seed)
# The returned predictions are stored but not used further in this function.
y_test_pred = evaluate_model(model, X_test, y_test, data_seed, model_seed)
# Earlier single-run variant, kept commented out.
#X_train, y_train, X_test, y_test = prepare_data()
#model = train_model(X_train, y_train, X_test, y_test)
#evaluate_model(model, X_test, y_test)
# method to sort the predictions into CM categories
# and to visual the distribution of each feature for the FN and FP
def visu():
    """Train on 10 random data splits and save one summary image per split.

    For each data seed drawn from [0, 10000) (model seed fixed to
    ``RANDOM_SEED``) the function trains and evaluates a model, labels every
    test row with its confusion-matrix cell, draws feature box plots for the
    false negatives, the false positives, ``X_test`` and ``X_train``, and
    finally tiles those plots together with the loss and confusion-matrix
    PDFs into a 3x2 PNG.

    NOTE(review): the loss and confusion-matrix PDFs are read from
    ``./plots/`` here but are not written by this function -- confirm that
    the training/evaluation code actually saves them.
    """
    for data_seed in np.random.randint(0, 10000, size=10):
        model_seed = RANDOM_SEED

        # Train and evaluate on this particular split.
        train_df, test_df = prepare_test_train_df(data_seed)
        X_train, y_train, X_test, y_test = prepare_data(train_df, test_df, data_seed, model_seed)
        model = train_model(X_train, y_train, X_test, y_test, data_seed, model_seed)
        y_test_pred = evaluate_model(model, X_test, y_test, data_seed, model_seed)

        # Attach the flattened predictions to a copy of the test frame,
        # aligned on the test frame's own index.
        results = test_df.copy()
        predicted_flags = pd.Series(y_test_pred.reshape(-1))
        predicted_flags.index = results.index
        results['predicted'] = predicted_flags.astype(bool)

        # Create new column based on conditions for the confusion matrix
        results['CM'] = results.apply(determine_result, axis=1)

        # Keep only the feature columns of the misclassified rows.
        helper_columns = ["alarm", "predicted", "CM"]
        results_FN = results.loc[results['CM'] == 'FN'].drop(columns=helper_columns)
        results_FP = results.loc[results['CM'] == 'FP'].drop(columns=helper_columns)
        has_false_negatives = not results_FN.empty
        has_false_positives = not results_FP.empty

        img_FN = []
        img_FP = []

        # Convert PDF to image
        img_loss = convert_from_path(f'./plots/plot_loss_data{data_seed}_model{model_seed}.pdf')[0]
        img_cm = convert_from_path(f'./plots/plot_cm_data{data_seed}_model{model_seed}.pdf')[0]

        # Error-class box plots only exist when that class is non-empty.
        if has_false_negatives:
            plot_features(results_FN, "FN", data_seed, model_seed)
            img_FN = convert_from_path(f'./plots/plot_FN_data{data_seed}_model{model_seed}.pdf')[0]
        if has_false_positives:
            plot_features(results_FP, "FP", data_seed, model_seed)
            img_FP = convert_from_path(f'./plots/plot_FP_data{data_seed}_model{model_seed}.pdf')[0]

        plot_features(pd.DataFrame(X_test), "X_test", data_seed, model_seed)
        img_X_test = convert_from_path(f'./plots/plot_X_test_data{data_seed}_model{model_seed}.pdf')[0]
        plot_features(pd.DataFrame(X_train), "X_train", data_seed, model_seed)
        img_X_train = convert_from_path(f'./plots/plot_X_train_data{data_seed}_model{model_seed}.pdf')[0]

        # Tile everything into one overview figure: feature distributions on
        # top, training diagnostics in the middle, error classes at the bottom.
        fig, axes = plt.subplots(3, 2, figsize=(24, 12))
        tiles = [
            (axes[0, 0], img_X_test),
            (axes[0, 1], img_X_train),
            (axes[1, 0], img_loss),
            (axes[1, 1], img_cm),
        ]
        if has_false_negatives:
            tiles.append((axes[2, 0], img_FN))
        if has_false_positives:
            tiles.append((axes[2, 1], img_FP))
        for cell, image in tiles:
            cell.imshow(image)

        # Hide the axis decorations of every cell, including empty ones.
        for cell in axes.flat:
            cell.axis('off')

        plt.tight_layout()
        plt.savefig(f"./plots/fix-data/results/results_data{data_seed}_model{model_seed}.png")
        #plt.show()
        plt.close()
# TODO: Get the files corresponding to those ids and plot some information
if __name__ == "__main__":
main()
#visu()
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# Compare the spread of the logged metrics across the seed experiments.
# Each results CSV is drawn as a box plot in its own cell of a 2x3 grid that
# shares the y axis, and the figure is saved as 'error_distribution.pdf'.
FILE_PATHS = ['./logs/random.csv', './logs/fix-data.csv','./logs/fix-model.csv', './logs/fix-data-param-romaric.csv', './logs/fix-model-param-romaric.csv']
columns = ['Accuracy'] #['Accuracy', 'Precision', 'Recall']

# One (grid cell, title) entry per file in FILE_PATHS, in the same order.
# Cell (1, 0) deliberately receives no data: the second row lines up
# underneath the "fix data" / "fix model" columns of the first row.
PANELS = [
    ((0, 0), 'Random Seeds'),
    ((0, 1), 'Fix data, Random model seeds'),
    ((0, 2), 'Fix model, Random data seeds'),
    ((1, 1), 'Fix data (Romaric)'),
    ((1, 2), 'Fix model (Romaric)'),
]

fig, axes = plt.subplots(2, 3, figsize=(12, 10), sharey=True)

for path, (cell, title) in zip(FILE_PATHS, PANELS):
    # The first CSV row is the header; keep only the metrics being compared.
    frame = pd.read_csv(path, header=0)[columns]
    sns.boxplot(data=frame, ax=axes[cell])
    axes[cell].set_title(title)

plt.tight_layout()
# Save before show(): showing may leave an empty figure behind on some backends.
plt.savefig('error_distribution.pdf')
plt.show()
\ No newline at end of file
results/fix-data/results_data123456789_model1418.png

215 KiB

results/fix-data/results_data123456789_model1742.png

216 KiB

results/fix-data/results_data123456789_model197.png

223 KiB

results/fix-data/results_data123456789_model2456.png

217 KiB

results/fix-data/results_data123456789_model3355.png

218 KiB

results/fix-data/results_data123456789_model3856.png

219 KiB

results/fix-data/results_data123456789_model4043.png

211 KiB

results/fix-data/results_data123456789_model4224.png

220 KiB

results/fix-data/results_data123456789_model4955.png

218 KiB

0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment