diff --git a/clustering/mesures_clustering.py b/clustering/mesures_clustering.py index ed08e4befdb9b20760844cb36df36d32cc08f1a2..af1ffdf76328c289e7b40c06d2b9f489c73ef130 100644 --- a/clustering/mesures_clustering.py +++ b/clustering/mesures_clustering.py @@ -1,14 +1,10 @@ -def compute_silhouette_scores(reduced_embeddings, clustering_results): +def compute_silhouette_scores(reduced_embeddings, labels): """ Calcule les scores de silhouette pour différents nombres de clusters. :param reduced_embeddings: Matrice des embeddings réduits - :param clustering_results: Dictionnaire contenant les labels prédits pour chaque nombre de clusters - :return: Dictionnaire des scores de silhouette + :param labels: les labels prédits par les algos de clustering + :return: silhouette score """ - silhouette_scores = {} - for n_clusters, labels in clustering_results.items(): - silhouette_avg = silhouette_score(reduced_embeddings, labels) - silhouette_scores[n_clusters] = silhouette_avg - print(f"Nombre de clusters: {n_clusters}, Silhouette Score: {silhouette_avg:.2f}") - return silhouette_scores \ No newline at end of file + silhouette_avg = silhouette_score(reduced_embeddings, labels) + return silhouette_avg \ No newline at end of file