Commit 9daba7f5 authored by Boshra Ariguib

fixed seeds and added logs

parent 2631dc18
Run, Accuracy, Recall, Precision, F1 Score, F2 Score, AUC
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.919908466819222
0, 0.7619047619047619, 0.8695652173913043, 0.7407407407407407, 0.7999999999999999, 0.8403361344537816, 0.9176201372997711
0, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.9181922196796339
0, 0.7976190476190477, 0.8913043478260869, 0.7735849056603774, 0.8282828282828283, 0.8649789029535865, 0.9256292906178489
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.927345537757437
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.914187643020595
0, 0.8690476190476191, 0.8695652173913043, 0.8888888888888888, 0.8791208791208792, 0.8733624454148471, 0.9267734553775744
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9067505720823799
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9319221967963387
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9221967963386728
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.919908466819222
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9187643020594966
0, 0.7619047619047619, 0.8913043478260869, 0.7321428571428571, 0.8039215686274508, 0.8541666666666666, 0.9136155606407323
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9273455377574371
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9153318077803204
0, 0.7976190476190477, 0.8913043478260869, 0.7735849056603774, 0.8282828282828283, 0.8649789029535865, 0.9176201372997712
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9204805491990848
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9227688787185354
\ No newline at end of file
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.891304347826087
0, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.8775743707093822
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.9250572082379862
0, 0.7976190476190477, 0.9130434782608695, 0.7636363636363637, 0.8316831683168316, 0.8786610878661087, 0.9250572082379863
0, 0.7857142857142857, 0.9130434782608695, 0.75, 0.8235294117647057, 0.875, 0.8758581235697941
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.8981693363844393
0, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9330663615560641
0, 0.7976190476190477, 0.9347826086956522, 0.7543859649122807, 0.8349514563106796, 0.892116182572614, 0.9244851258581237
0, 0.7738095238095238, 0.5869565217391305, 1.0, 0.7397260273972603, 0.6398104265402843, 0.9067505720823799
0, 0.8095238095238095, 0.9565217391304348, 0.7586206896551724, 0.8461538461538461, 0.9090909090909092, 0.9382151029748284
Run, Accuracy, Recall, Precision, F1 Score, F2 Score, AUC
0, 0.7619047619047619, 0.9090909090909091, 0.7142857142857143, 0.8, 0.8620689655172413, 0.91875
0, 0.7738095238095238, 0.9024390243902439, 0.7115384615384616, 0.7956989247311829, 0.8564814814814816, 0.9160521837776517
0, 0.75, 0.9411764705882353, 0.7272727272727273, 0.8205128205128205, 0.8888888888888888, 0.803921568627451
0, 0.7738095238095238, 0.9148936170212766, 0.7413793103448276, 0.819047619047619, 0.8739837398373983, 0.885566417481311
0, 0.7738095238095238, 0.8076923076923077, 0.8235294117647058, 0.8155339805825242, 0.8108108108108109, 0.8311298076923077
0, 0.7023809523809523, 0.8478260869565217, 0.6842105263157895, 0.7572815533980581, 0.8091286307053941, 0.8707093821510298
0, 0.75, 0.9347826086956522, 0.7049180327868853, 0.8037383177570093, 0.8775510204081632, 0.9250572082379863
0, 0.7738095238095238, 0.8913043478260869, 0.7454545454545455, 0.8118811881188119, 0.8577405857740587, 0.908466819221968
0, 0.75, 0.9024390243902439, 0.6851851851851852, 0.7789473684210526, 0.8486238532110092, 0.8865570051049347
0, 0.8214285714285714, 0.9347826086956522, 0.7818181818181819, 0.8514851485148516, 0.8995815899581592, 0.9164759725400459
0, 0.7380952380952381, 0.88, 0.7333333333333333, 0.8, 0.8461538461538461, 0.8923529411764706
0, 0.75, 0.8775510204081632, 0.7413793103448276, 0.8037383177570093, 0.8464566929133858, 0.8967930029154519
0, 0.6547619047619048, 0.7906976744186046, 0.6296296296296297, 0.7010309278350516, 0.7522123893805309, 0.8077141236528644
0, 0.8095238095238095, 0.8823529411764706, 0.8181818181818182, 0.8490566037735848, 0.8687258687258687, 0.8835412953060012
0, 0.7738095238095238, 0.8775510204081632, 0.7678571428571429, 0.819047619047619, 0.8531746031746031, 0.8594752186588921
0, 0.7261904761904762, 0.84, 0.7368421052631579, 0.7850467289719626, 0.8171206225680934, 0.8664705882352941
0, 0.8214285714285714, 0.9230769230769231, 0.8135593220338984, 0.8648648648648649, 0.8988764044943821, 0.9122596153846154
0, 0.6547619047619048, 0.7272727272727273, 0.6530612244897959, 0.6881720430107526, 0.7111111111111111, 0.7704545454545455
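The two CSV blocks above are per-run metric logs; each row is one evaluation appended by evaluate_model (see the f.write call in the diff below), and the leading Run column is a constant 0. The F1 and F2 columns follow directly from Precision and Recall; the snippet below is a minimal check, using values copied from the first data row of the first log:

def f_beta(precision: float, recall: float, beta: float) -> float:
    # Generic F-beta score: (1 + beta^2) * P * R / (beta^2 * P + R)
    return (1 + beta**2) * precision * recall / (beta**2 * precision + recall)

precision, recall = 0.7454545454545455, 0.8913043478260869  # first data row above
print(f_beta(precision, recall, beta=1))  # ~0.8119, matches the F1 Score column
print(f_beta(precision, recall, beta=2))  # ~0.8577, matches the F2 Score column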
@@ -43,14 +43,14 @@ RANDOM_SEED = 123456789
 
 ##### HYPERPARAMETERS #####
-EPOCHS = 250
-BATCH_SIZE = 16
+EPOCHS = 200 #350
+BATCH_SIZE = 16 #32
 CRITERION = nn.BCELoss()
-OPTIMIZER = torch.optim.Adam
-LEARNING_RATE = 0.01
-GROWTH_RATE = 16
-DROP_RATE = 0.5
-SCHEDULER_PATIENCE = 10
+OPTIMIZER = torch.optim.Adam #torch.optim.SGD
+LEARNING_RATE = 1e-3 #5*1e-3
+GROWTH_RATE = 16 #32
+DROP_RATE = 0.2 #0.5
+SCHEDULER_PATIENCE = 15 #10
 SCHEDULER_FACTOR = 0.1
 SCHEDULER_EPS = 1e-8
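The scheduler constants above (patience, factor, eps) match the keyword arguments of torch.optim.lr_scheduler.ReduceLROnPlateau; the scheduler construction itself is outside this hunk, so the wiring below is only a minimal sketch under that assumption, with nn.Linear standing in for the project's MLP:

import torch
import torch.nn as nn

LEARNING_RATE = 1e-3
SCHEDULER_PATIENCE = 15
SCHEDULER_FACTOR = 0.1
SCHEDULER_EPS = 1e-8

model = nn.Linear(10, 1)  # placeholder; the real model is the MLP defined in this file
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=SCHEDULER_FACTOR,      # shrink the learning rate 10x on a plateau
    patience=SCHEDULER_PATIENCE,  # wait 15 epochs without improvement first
    eps=SCHEDULER_EPS,            # ignore lr updates smaller than this
)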
@@ -105,7 +105,7 @@ class MLP(nn.Module):
         return y
 
-def prepare_data() -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+def prepare_data(seed = RANDOM_SEED) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
     # Load data
     df = pd.read_pickle(PICKLE_PATH)
@@ -113,7 +113,7 @@ def prepare_data() -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
     df = df.drop(columns=S_features_filter)
 
     # split the data into training and testing sets
-    train_df, test_df = train_test_split(df, test_size=0.2, random_state=RANDOM_SEED)
+    train_df, test_df = train_test_split(df, test_size=0.2, random_state=seed)
 
     print(
         f"Train alarm distribution (before undersampling): {train_df['alarm'].value_counts()}"
@@ -123,8 +123,8 @@ def prepare_data() -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
     min_alarm = train_df["alarm"].value_counts().min()
     train_df = pd.concat(
         [
-            train_df[train_df["alarm"] == 0].sample(min_alarm),
-            train_df[train_df["alarm"] == 1].sample(min_alarm),
+            train_df[train_df["alarm"] == 0].sample(min_alarm, random_state=seed),
+            train_df[train_df["alarm"] == 1].sample(min_alarm, random_state=seed),
         ]
     )
     print(f"Train alarm distribution: {train_df['alarm'].value_counts()}")
@@ -138,24 +138,26 @@ def prepare_data() -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
     return X_train, y_train, X_test, y_test
 
-def train_model(X_train, y_train, X_test, y_test):
+def train_model(X_train, y_train, X_test, y_test, seed = RANDOM_SEED):
-    torch.manual_seed(RANDOM_SEED)
+    torch.manual_seed(seed)
 
     # Setting up the data loader
     train_loader = torch.utils.data.DataLoader(
         list(zip(X_train, y_train)),
         batch_size=BATCH_SIZE,
-        shuffle=True,
+        shuffle=True
     )
 
     # Setting up the test loader
     test_loader = torch.utils.data.DataLoader(
         list(zip(X_test, y_test)),
         batch_size=BATCH_SIZE,
-        shuffle=False,
+        shuffle=False
     )
 
+    torch.manual_seed(seed)
+
     # Define model
     model = MLP().to(device)
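The hunk above threads a per-run seed into train_model, but torch.manual_seed on its own does not cover Python's random module, NumPy, or cuDNN's non-deterministic kernels. A broader seeding helper, not part of this commit, could look like this:

import random
import numpy as np
import torch

def seed_everything(seed: int) -> None:
    # Seed the common RNG sources; stricter than the bare torch.manual_seed above.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)                    # CPU RNG (and CUDA on recent PyTorch)
    torch.cuda.manual_seed_all(seed)           # all GPUs, explicit for older versions
    torch.backends.cudnn.deterministic = True  # reproducible conv kernels, at some speed cost
    torch.backends.cudnn.benchmark = False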
@@ -301,10 +303,14 @@ def evaluate_model(model, X_test, y_test):
     plt.ylabel("True")
     plt.show()
 
+    with open("log_fix_init.csv", "a") as f:
+        f.write(f'0, {accuracy}, {recall}, {precision}, {f1}, {f2}, {auc}\n')
 
 def main():
+    for seed in [5 ** i + 3 for i in range(10)]:
         X_train, y_train, X_test, y_test = prepare_data()
-    model = train_model(X_train, y_train, X_test, y_test)
+        model = train_model(X_train, y_train, X_test, y_test, seed)
         evaluate_model(model, X_test, y_test)
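With main() now looping over ten seeds and evaluate_model appending one row per run to log_fix_init.csv, the logs can be summarized after the fact. A small sketch, assuming the file starts with the header row shown in the CSVs above:

import pandas as pd

# Aggregate the per-seed runs appended to log_fix_init.csv.
df = pd.read_csv("log_fix_init.csv", skipinitialspace=True)
print(df.drop(columns="Run").agg(["mean", "std"]).round(4))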