Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • c20verne/choquet-rank
  • c20verne/choquet-rank-student-project
2 results
Show changes
Commits on Source (5)
Showing
with 722 additions and 5 deletions
......@@ -174,4 +174,5 @@ $RECYCLE.BIN/
*.RData
RankLib/
svm_rank_linux64/
results/exp_*
\ No newline at end of file
results/exp_*
results/train
\ No newline at end of file
FROM ubuntu:latest
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y r-base wget gzip
RUN R -e "install.packages(c('kappalab', 'jsonlite'), repos='http://cran.rstudio.com/')"
RUN wget https://download.joachims.org/svm_rank/current/svm_rank_linux64.tar.gz
RUN mkdir choquet-rank
WORKDIR choquet-rank
RUN mkdir svm_rank_linux64
WORKDIR svm_rank_linux64
RUN mv ../../svm_rank_linux64.tar.gz .
RUN tar -xzf svm_rank_linux64.tar.gz
WORKDIR ..
RUN apt-get update && apt-get install -y maven
COPY data/ ./data
COPY results/ ./results
COPY scripts/ ./scripts
COPY src/ ./src
COPY pom.xml/ .
COPY Makefile .
RUN mkdir RankLib
COPY RankLib/RankLib-2.18.jar RankLib/RankLib-2.18.jar
RUN mvn package
ENTRYPOINT ["bash"]
\ No newline at end of file
......@@ -7,7 +7,7 @@ exp_time = results/exp_time
jar_cmd = java -cp target/choquet-rank-1.0.0-jar-with-dependencies.jar
install:
mvn package
mvn -DskipTests package
exp_passive:
rm -rf ${exp_passive}
mkdir -p ${exp_passive}/learn
......
# Instructions pour utiliser l'API
Pour pouvoir utiliser l'API, vous avez besoin dans un premier temps de suivre les instructions de la section "Experiments requirements" du fichier README.md afin de compiler le projet. Il existe alors deux points d'entrée pour l'API en utilisant la ligne de commande.
## Création des jeux d'entraînement et de test
La première étape est de séparer l'ensemble de règles en un jeu d'entraînement et un jeu de test qui serviront pour alimenter notre modèle. Nous pouvons appeler l'API à l'aide la ligne de commande suivante :
```bash
java -cp target/choquet-rank-1.0.0-jar-with-dependencies.jar io.gitlab.chaver.minimax.cli.SplitTrainingTestCli
```
Nous allons illustrer l'usage de cette API avec l'exemple suivant après avoir exécuté la commande `mkdir results/train` pour créer un dossier qui servira à stocker nos résultats :
```bash
java -cp target/choquet-rank-1.0.0-jar-with-dependencies.jar io.gitlab.chaver.minimax.cli.SplitTrainingTestCli -d results/rules/iris --train 0.26 -r results/train/iris -m phi:kruskal:yuleQ --seed 1234
```
L'usage de chaque paramètre est le suivant :
- `-d` : représente le chemin des règles à utiliser pour l'entraînement du modèle. Dans cet exemple, deux fichiers vont être lus :
- *results/rules/iris_sols.jsonl* : fichier où chaque ligne représente une règle d'association
- *results/rules/iris_prop.jsonl* : fichier qui contient d'autres informations utiles comme le nombre de transactions de la base de données
- `--train` : représente le pourcentage de règles à utiliser pour l'entraînement du modèle, c'est un réel compris entre 0 et 1. Dans cet exemple, nous utilisons 26% de règles pour entraîner le modèle et le reste pour le test.
- `-r` : représente le chemin où seront stockés le fichier d'entraînement et de test. Dans cet exemple, deux fichiers vont être créés :
- *results/train/iris_train.jsonl* : fichier qui contiendra toutes les règles utilisées pour l'entraînement du modèle
- *results/train/iris_test.jsonl* : idem avec les règles utilisées pour le test du modèle
- `-m` : représente les mesures à calculer pour chaque règle et qui seront utilisées pour apprendre le modèle dans la prochaine étape. Dans cet exemple, nous utilisons les mesures phi, kruskal et yuleQ (chaque mesure étant séparée par :). Les mesures suivantes sont disponibles (voir la classe `RuleMeasures` du package `io.gitlab.chaver.minimax.rules.io`) :
```java
// Measure names
public static final String confidence = "confidence";
public static final String lift = "lift";
public static final String cosine = "cosine";
public static final String phi = "phi";
public static final String kruskal = "kruskal";
public static final String yuleQ = "yuleQ";
public static final String addedValue = "pavillon";
public static final String certainty = "certainty";
public static final String support = "support";
public static final String revsupport = "revsup";
```
- `--seed` : une seed qui permet d'assurer la reproductibilité des expériences
## Apprentissage du modèle et évaluation
Dans une deuxième étape, nous allons apprendre un modèle à l'aide du jeu d'entraînement généré précédemment et nous allons évaluer ce dernier sur le jeu de test. Nous pouvons appeler l'API correspondante à l'aide de la ligne de commande suivante :
```bash
java -cp target/choquet-rank-1.0.0-jar-with-dependencies.jar io.gitlab.chaver.minimax.cli.LearnFunctionAndRankCli
```
Nous allons illustrer l'usage de cette API avec l'exemple suivant :
```bash
java -cp target/choquet-rank-1.0.0-jar-with-dependencies.jar io.gitlab.chaver.minimax.cli.LearnFunctionAndRankCli -d results/rules/iris -m phi:kruskal:yuleQ --seed 1234 --tt results/train/iris -o linear -l kappalab -r results/train/iris
```
L'usage de `-d`, `-m` et `--seed` est le même que précédemment (ces paramètres doivent avoir la même valeur que précédemment). L'usage des autres paramètres est le suivant :
- `--tt` : correspond au chemin des fichiers d'entraînement de test du modèle (i.e. la valeur correspondante à `-r` lors de l'appel de l'API précédente)
- `-o` : nom de l'oracle utilisé, pour l'instant les valeurs suivantes sont possibles :
- **linear** : fonction linéaire (somme pondérée)
- **owa** : Ordered Weighted Average
- `-l` : nom de l'algorithme d'apprentissage utilisé pour apprendre le modèle, les valeurs suivantes sont possibles :
- **kappalab**
- **ahp**
- **svm**
- `-r` : chemin où sera stocké les résultats de l'apprentissage, 4 fichiers vont être créés dans cet exemple :
- *results/train/iris_func.jsonl* : représente la fonction apprise du modèle
- *results/train/iris_metrics.jsonl* : représente les métriques utilisées pour évaluer la qualité du modèle
- *results/train/iris_ordered_test_rules.jsonl* : fichier avec les règles du jeu de test classées à l'aide la fonction apprise du modèle
- *results/train/iris_ordered_test_rules_oracle.jsonl* : fichier avec les règles du jeu de test classées à l'aide de l'oracle
Il est également possible de préciser des paramètres supplémentaires en fonction de l'algorithme d'apprentissage choisi. Si c'est kappalab, alors les paramètres suivants peuvent être rajoutés :
- `--delta` : un nombre réel > 0 qui représente le delta minimum entre deux alternatives (valeur par défaut : 1e-5)
- `--kadd` : un entier > 0 qui représente la k-additivité du modèle (valeur par défaut : 2)
- `--sigf` : un entier > 0 qui représente le nombre de chiffres significatifs utilisé dans l'apprentissage (valeur par défaut : 3)
Pour svm, le paramètre suivant peut être ajouté :
- `-c` : un réel > 0 qui représente le paramètre de régularisation du modèle (valeur par défaut : 0.01)
Nous allons analyser chaque fichier résultat. Commençons par `iris_func.jsonl` qui représente la fonction apprise du modèle :
```json
{
"functionType": "mobiusChoquet",
"weights": [
0,
0.4148,
0.3067,
0.3215,
0.1618,
-0.2738,
0.069
],
"kAdditivity": 2,
"nbCriteria": 3,
"timeToLearn": 0.66,
"timeOut": false,
"nbIterations": 0,
"shapleyValues": [
0.3588,
0.4221,
0.2191
],
"interactionIndices": [
[
0,
0.1618,
-0.2738
],
[
0.1618,
0,
0.069
],
[
-0.2738,
0.069,
0
]
],
"obj": [
0.0001
],
"weightLabels": [
"{}",
"{phi}",
"{kruskal}",
"{yuleQ}",
"{phi,kruskal}",
"{phi,yuleQ}",
"{kruskal,yuleQ}"
],
"measureNames": [
"phi",
"kruskal",
"yuleQ"
]
}
```
Chaque champ a le rôle suivant :
- **functionType** : Le type de fonction apprise
- **measureNames** : Le nom de chaque mesure utilisée en entrée de la fonction
- **weights** et **weightLabels** : Les poids de la fonction apprise ainsi que le label associé à chaque poids. Dans l'exemple ci-dessus, nous avons les poids suivants :
| Label | Poids |
| --------------- | ------- |
| {} | 0 |
| {phi} | 0.4148 |
| {kruskal} | 0.3067 |
| {yuleQ} | 0.3215 |
| {phi,kruskal} | 0.1618 |
| {phi,yuleQ} | -0.2738 |
| {kruskal,yuleQ} | 0.069 |
- **kAdditivity** : La k-additivité du modèle (utile pour une fonction de type mobiusChoquet)
- **nbCriteria** : Le nombre de mesures utilisé en entrée de la fonction
- **timeToLearn** : Temps (en secondes) nécessaire pour apprendre la fonction (ici 0.66 secondes)
- **timeOut** : Booléen qui indique si l'apprentissage a été réalisé avec succès dans le temps imparti (si = TRUE cela signifie que l'on a pas pu apprendre de fonction dans le temps imparti)
- **nbIterations** : Valeur pas utile ici
- **shapleyValues** : Ce champ existe uniquement si le type de la fonction est mobiusChoquet. Il représente les valeurs de Shapley pour chaque mesure. Dans cet exemple, nous avons les valeurs suivantes :
| Mesure | Valeur de Shapley |
| ------- | ----------------- |
| phi | 0.3588 |
| kruskal | 0.4221 |
| yuleQ | 0.2191 |
- **interactionIndices** : Les indices d'interaction entre les différentes mesures si le type de la fonction est mobiusChoquet. Dans cet exemple, nous avons les indices d'interaction suivants :
| | phi | kruskal | yuleQ |
| ----------- | ------- | ------- | ------- |
| **phi** | 0 | 0.1618 | -0.2738 |
| **kruskal** | 0.1618 | 0 | 0.069 |
| **yuleQ** | -0.2738 | 0.069 | 0 |
- **obj** : Pas utile ici
Le fichier `iris_metrics.jsonl` a la forme suivante :
```json
{"kendall":0.999960332152113,"rec@1%":1.0,"AP@10%":1.0,"rec@10%":1.0,"spearman":0.999920664304226,"AP@1%":1.0}
```
Il représente un ensemble de métriques utilisées pour évaluer la qualité de la fonction apprise.
Enfin, le fichier `iris_ordered_test_rules.jsonl` a la forme suivante (même forme pour le fichier de règles classées à l'aide de l'oracle `iris_ordered_test_rules_oracle.jsonl` sauf qu'il n'y a pas de score pour elles) :
```json
{"x":[13],"y":[1,10],"measureValues":{"phi":1.0,"yuleQ":1.0,"kruskal":1.0},"score":1.0}
{"x":[10],"y":[1,13],"measureValues":{"phi":1.0,"yuleQ":1.0,"kruskal":1.0},"score":1.0}
{"x":[1],"y":[10,13],"measureValues":{"phi":1.0,"yuleQ":1.0,"kruskal":1.0},"score":1.0}
{"x":[4,13],"y":[1,10],"measureValues":{"phi":0.9488738697784996,"yuleQ":0.9996303742812153,"kruskal":0.9378170749383968},"score":0.9592490587495673}
{"x":[13],"y":[1,4,10],"measureValues":{"phi":0.9488738697784999,"yuleQ":0.9996303742812153,"kruskal":0.9336012834087966},"score":0.9569830708024073}
{"x":[1],"y":[4,10,13],"measureValues":{"phi":0.9488738697784999,"yuleQ":0.9996303742812153,"kruskal":0.9336012834087966},"score":0.9569830708024073}
{"x":[15],"y":[3,12],"measureValues":{"phi":0.913059470358805,"yuleQ":0.9942253850314515,"kruskal":0.8941710652068109},"score":0.929001794156864}
{"x":[2],"y":[14],"measureValues":{"phi":0.9010118042966065,"yuleQ":0.9930542928261625,"kruskal":0.8851183844495352},"score":0.922060751191058}
```
Chaque ligne représente une règle d'association du jeu de test. Les règles sont ordonnées par score décroissant, où le score est calculé à l'aide de la fonction apprise.
## Labels du dataset Eisen
Le fichier `data/eisen.names` contient le nom de chaque item du dataset eisen, où la première ligne réprésente le nom de l'item 1, la deuxième ligne le nom de l'item 2, etc...
Vous pouvez utiliser ce fichier pour afficher les noms des items de x et y dans chaque règle de eisen. Par exemple, si vous avez la règle suivante :
```json
{"x":[1,2],"y":[12]}
```
En utilisant le fichier eisen.names, vous pourrez afficher cette règle sur l'interface graphique de l'utilisateur :
```
{GO:0005737,GO:0016787} => {PHENOT:"reduced fitness in rich medium (YPD)"}
```
\ No newline at end of file
package io.gitlab.chaver.minimax.cli;
import com.google.gson.Gson;
import io.gitlab.chaver.minimax.io.Alternative;
import io.gitlab.chaver.minimax.io.IAlternative;
import io.gitlab.chaver.minimax.learn.oracle.LinearFunctionOracle;
import io.gitlab.chaver.minimax.learn.oracle.OWAOracle;
import io.gitlab.chaver.minimax.learn.oracle.ScoreFunctionOracle;
import io.gitlab.chaver.minimax.learn.train.AbstractRankingLearning;
import io.gitlab.chaver.minimax.learn.train.passive.AHPRankLearn;
import io.gitlab.chaver.minimax.learn.train.passive.KappalabRankLearn;
import io.gitlab.chaver.minimax.learn.train.passive.SVMRankLearn;
import io.gitlab.chaver.minimax.ranking.*;
import io.gitlab.chaver.minimax.rules.io.RuleWithMeasures;
import io.gitlab.chaver.minimax.score.FunctionParameters;
import io.gitlab.chaver.minimax.score.IScoreFunction;
import io.gitlab.chaver.minimax.score.ScoreFunctionFactory;
import io.gitlab.chaver.minimax.util.RandomUtil;
import picocli.CommandLine;
import picocli.CommandLine.Option;
import java.io.*;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import static io.gitlab.chaver.minimax.learn.train.LearnUtil.*;
public class LearnFunctionAndRankCli implements Callable<Integer> {
@Option(names = "-d", description = "Path of the rules data", required = true)
private String dataPath;
@Option(names = "--tt", description = "Path of the training/test data", required = true)
private String trainingTestDataPath;
@Option(names = "-m", description = "Measures used in the function", required = true, split = ":")
private String[] measures;
@Option(names = "-o", description = "Name of the oracle", required = true)
private String oracleName;
@Option(names = "-l", description = "Learning algorithm", required = true)
private String learnAlgorithm;
@Option(names = "--seed", description = "Seed for random number generation")
private long seed = 2994274L;
@Option(names = "-r", description = "Path of the result files", required = true)
private String resPath;
// Kappalab parameters
@Option(names = "--delta", description = "Delta parameter (kappalab, default value : 1e-5)")
private double delta = 1e-5;
@Option(names = "--kadd", description = "k-additivity of the model (kappalab, default value : 2)")
private int kAdd = 2;
@Option(names = "--sigf", description = "Number of significant figures (kappalab, default value : 3)")
private int sigf = 3;
// SVM parameters
@Option(names = "-c", description = "Regularisation parameter (svm, default value : 0.01)")
private double regularisationParameter = 0.01;
private List<RuleWithMeasures> readRules(String path) throws IOException {
try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
Gson gson = new Gson();
List<RuleWithMeasures> rules = new ArrayList<>();
String line;
while ((line = reader.readLine()) != null) {
rules.add(gson.fromJson(line, RuleWithMeasures.class));
}
return rules;
}
}
private void writeRules(String path, List<RuleWithMeasures> rules) throws IOException {
try (BufferedWriter writer = new BufferedWriter(new FileWriter(path))) {
Gson gson = new Gson();
for (RuleWithMeasures r : rules) {
writer.write(gson.toJson(r) + "\n");
}
}
}
private void writeObject(String path, Object o) throws IOException {
try (BufferedWriter writer = new BufferedWriter(new FileWriter(path))) {
Gson gson = new Gson();
writer.write(gson.toJson(o));
}
}
private List<IAlternative> getAlternatives(List<RuleWithMeasures> rules) {
List<IAlternative> alternatives = new ArrayList<>();
for (RuleWithMeasures r : rules) {
double[] vector = Arrays
.stream(measures)
.mapToDouble(m -> r.getMeasureValues().get(m))
.toArray();
alternatives.add(new Alternative(vector));
}
return alternatives;
}
private Comparator<IAlternative> getOracle(double[] weights) {
if (oracleName.equals(OWAOracle.TYPE)) {
return new OWAOracle(weights);
}
if (oracleName.equals(LinearFunctionOracle.TYPE)) {
return new LinearFunctionOracle(weights);
}
throw new RuntimeException("Wrong oracle type: " + oracleName);
}
private AbstractRankingLearning getLearningAlgo(Ranking<IAlternative> expectedRanking) {
if (learnAlgorithm.equals("kappalab")) {
KappalabRankLearn kappalab = new KappalabRankLearn(expectedRanking);
kappalab.setDelta(delta);
kappalab.setKAdditivity(kAdd);
kappalab.setSigf(sigf);
return kappalab;
}
if (learnAlgorithm.equals("ahp")) {
AHPRankLearn ahp = new AHPRankLearn(expectedRanking);
return ahp;
}
if (learnAlgorithm.equals("svm")) {
SVMRankLearn svm = new SVMRankLearn(expectedRanking);
svm.setRegularisationParameter(regularisationParameter);
return svm;
}
throw new RuntimeException("Wrong learning algorithm: " + learnAlgorithm);
}
private List<RankingMetric> getRankingMetrics(int nbRules) {
int top1 = (int) (0.01 * nbRules);
int top10 = (int) (0.1 * nbRules);
return Arrays.asList(new KendallConcordanceCoeff(), new SpearmanRankCorrelationCoefficient(),
new RecallMetric(top1), new RecallMetric(top10), new AveragePrecision(top1), new AveragePrecision(top10));
}
private Map<String, String> getRankingMetricLabels(int nbRules) {
int top1 = (int) (0.01 * nbRules);
int top10 = (int) (0.1 * nbRules);
Map<String, String> labels = new HashMap<>();
labels.put(KendallConcordanceCoeff.TYPE, KendallConcordanceCoeff.TYPE);
labels.put(SpearmanRankCorrelationCoefficient.TYPE, SpearmanRankCorrelationCoefficient.TYPE);
labels.put(RecallMetric.TYPE + "@" + top1, RecallMetric.TYPE + "@1%");
labels.put(RecallMetric.TYPE + "@" + top10, RecallMetric.TYPE + "@10%");
labels.put(AveragePrecision.TYPE + "@" + top1, AveragePrecision.TYPE + "@1%");
labels.put(AveragePrecision.TYPE + "@" + top10, AveragePrecision.TYPE + "@10%");
return labels;
}
private Map<String, Double> computeRankingMetricValues(List<RankingMetric> rankingMetrics,
Map<String, String> rankingMetricLabels,
Ranking<IAlternative> actualRanking,
Ranking<IAlternative> expectedRanking) {
Map<String, Double> rankingMetricValues = new HashMap<>();
for (RankingMetric metric : rankingMetrics) {
double value = metric.compute(expectedRanking, actualRanking);
String label = rankingMetricLabels.get(metric.getName());
rankingMetricValues.put(label, value);
}
return rankingMetricValues;
}
@Override
public Integer call() throws Exception {
int nbTransactions = getNbTransactions(dataPath + "_prop.jsonl");
int nbMeasures = measures.length;
List<RuleWithMeasures> trainingRules = readRules(trainingTestDataPath + "_train.jsonl");
List<RuleWithMeasures> testRules = readRules(trainingTestDataPath + "_test.jsonl");
List<IAlternative> trainingAlternatives = getAlternatives(trainingRules);
List<IAlternative> testAlternatives = getAlternatives(testRules);
RandomUtil.getInstance().setSeed(seed);
double[] randomWeights = RandomUtil.getInstance().generateRandomWeights(nbMeasures);
Comparator<IAlternative> oracle = getOracle(randomWeights);
Ranking<IAlternative> expectedRanking = computeRankingWithOracle(oracle, trainingAlternatives);
AbstractRankingLearning algo = getLearningAlgo(expectedRanking);
FunctionParameters functionParameters = algo.learn();
functionParameters.addWeightLabels(measures);
functionParameters.setMeasureNames(measures);
IScoreFunction<IAlternative> func = ScoreFunctionFactory.getScoreFunction(functionParameters);
Ranking<IAlternative> actualTestAlternativesRanking = computeRankingWithOracle(new ScoreFunctionOracle(func), testAlternatives);
Ranking<IAlternative> expectedTestAlternativesRanking = computeRankingWithOracle(oracle, testAlternatives);
List<RankingMetric> rankingMetrics = getRankingMetrics(testRules.size());
Map<String, String> rankingMetricLabels = getRankingMetricLabels(testRules.size());
Map<String, Double> rankingMetricValues = computeRankingMetricValues(rankingMetrics, rankingMetricLabels,
actualTestAlternativesRanking, expectedTestAlternativesRanking);
List<RuleWithMeasures> testRulesRankedWithLearnedFunc = Arrays
.stream(actualTestAlternativesRanking.getRanking())
.mapToObj(i -> testRules.get(i))
.collect(Collectors.toCollection(ArrayList::new));
for (int i = 0; i < testRulesRankedWithLearnedFunc.size(); i++) {
int pos = actualTestAlternativesRanking.getRanking()[i];
IAlternative a = testAlternatives.get(pos);
testRulesRankedWithLearnedFunc.get(i).setScore(func.computeScore(a));
}
List<RuleWithMeasures> testRulesRankedWithOracle = Arrays
.stream(expectedTestAlternativesRanking.getRanking())
.mapToObj(i -> new RuleWithMeasures(testRules.get(i), false))
.collect(Collectors.toCollection(ArrayList::new));
writeRules(resPath + "_ordered_test_rules.jsonl", testRulesRankedWithLearnedFunc);
writeRules(resPath + "_ordered_test_rules_oracle.jsonl", testRulesRankedWithOracle);
writeObject(resPath + "_func.jsonl", functionParameters);
writeObject(resPath + "_metrics.jsonl", rankingMetricValues);
return 0;
}
public static void main(String[] args) {
int exitCode = new CommandLine(new LearnFunctionAndRankCli()).execute(args);
System.exit(exitCode);
}
}
package io.gitlab.chaver.minimax.cli;
import com.google.gson.Gson;
import io.gitlab.chaver.minimax.io.Alternative;
import io.gitlab.chaver.minimax.io.IAlternative;
import io.gitlab.chaver.minimax.normalizer.INormalizer;
import io.gitlab.chaver.minimax.rules.io.RuleMeasures;
import io.gitlab.chaver.minimax.rules.io.RuleWithMeasures;
import io.gitlab.chaver.minimax.util.RandomUtil;
import io.gitlab.chaver.mining.rules.io.IRule;
import picocli.CommandLine;
import picocli.CommandLine.Option;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import static io.gitlab.chaver.minimax.learn.train.LearnUtil.*;
public class SplitTrainingTestCli implements Callable<Integer> {
@Option(names = "-d", description = "Path of the rules data", required = true)
private String dataPath;
@Option(names = "--train", description = "%% of data used for the training (value in [0,1])", required = true)
private double trainingPercentage;
@Option(names = "-r", description = "Path of the result files", required = true)
private String resPath;
@Option(names = "-m", description = "Measures to compute for each rule", required = true, split = ":")
private String[] measures;
@Option(names = "--seed", description = "Seed for random number generation")
private long seed = 2994274L;
@Option(names = "--smooth", description = "Smooth measures")
private double smoothCounts = 0.1d;
private <T> void writeJsonLines(List<T> l, String path) throws IOException {
try (BufferedWriter writer = new BufferedWriter(new FileWriter(path))) {
Gson gson = new Gson();
for (T elt : l) {
writer.write(gson.toJson(elt) + "\n");
}
}
}
@Override
public Integer call() throws Exception {
List<IRule> rules = readAssociationRules(dataPath + "_sols.jsonl");
int nbTransactions = getNbTransactions(dataPath + "_prop.jsonl");
List<IAlternative> alternatives = rules
.stream()
.map(r -> new Alternative(new RuleMeasures(r, nbTransactions, smoothCounts).computeMeasures(measures)))
.collect(Collectors.toCollection(ArrayList::new));
alternatives = INormalizer.tchebychefNormalize(alternatives);
RandomUtil.getInstance().setSeed(seed);
int foldSize = (int) (trainingPercentage * rules.size());
int[][] folds = RandomUtil.getInstance().kFolds(1, rules.size(), foldSize);
List<IAlternative> trainingAlternatives = getTrainingData(alternatives, folds).get(0);
List<IRule> trainingRules = getTrainingData(rules, folds).get(0);
List<IAlternative> testAlternatives = getTestData(alternatives, folds).get(0);
List<IRule> testRules = getTestData(rules, folds).get(0);
List<RuleWithMeasures> trainingRuleWithMeasures = new ArrayList<>();
for (int i = 0; i < trainingRules.size(); i++) {
RuleWithMeasures r = new RuleWithMeasures(trainingRules.get(i), trainingAlternatives.get(i), measures);
trainingRuleWithMeasures.add(r);
}
writeJsonLines(trainingRuleWithMeasures, resPath + "_train.jsonl");
List<RuleWithMeasures> testRuleWithMeasures = new ArrayList<>();
for (int i = 0; i < testRules.size(); i++) {
RuleWithMeasures r = new RuleWithMeasures(testRules.get(i), testAlternatives.get(i), measures);
testRuleWithMeasures.add(r);
}
writeJsonLines(testRuleWithMeasures, resPath + "_test.jsonl");
return 0;
}
public static void main(String[] args) {
int exitCode = new CommandLine(new SplitTrainingTestCli()).execute(args);
System.exit(exitCode);
}
}
......@@ -29,7 +29,7 @@ public class RecallMetric implements RankingMetric {
topKB.add(predictedRanking.getObjects()[predictedRanking.getRanking()[i]]);
}
topKA.retainAll(topKB);
double value = (double) topKA.size() / k;
double value = (double) topKA.size() / topKB.size();
if (value < 0 || value > 1) {
throw new RuntimeException("R@k must be between 0 and 1");
}
......
package io.gitlab.chaver.minimax.rules.io;
import io.gitlab.chaver.minimax.io.IAlternative;
import io.gitlab.chaver.mining.rules.io.IRule;
import lombok.Getter;
import lombok.Setter;
import java.util.HashMap;
import java.util.Map;
@Getter
public class RuleWithMeasures {
private int[] x;
private int[] y;
private Map<String, Double> measureValues;
private @Setter Double score;
public RuleWithMeasures(IRule r, IAlternative a, String[] measures) {
x = r.getX();
y = r.getY();
measureValues = new HashMap<>();
for (int i = 0; i < measures.length; i++) {
measureValues.put(measures[i], a.getVector()[i]);
}
}
public RuleWithMeasures(RuleWithMeasures copy, boolean withScore) {
x = copy.getX();
y = copy.getY();
measureValues = copy.getMeasureValues();
score = withScore ? copy.score : null;
}
}
......@@ -4,6 +4,14 @@ import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashSet;
import java.util.Set;
import static io.gitlab.chaver.minimax.util.BitSetUtil.intToBitSet;
import static io.gitlab.chaver.minimax.util.BitSetUtil.isLexBefore;
/**
* Class to save the parameters of a score function in a file
*/
......@@ -23,4 +31,68 @@ public class FunctionParameters {
private double[] shapleyValues;
private double[][] interactionIndices;
private double[] obj;
private String[] weightLabels;
private String[] measureNames;
private BitSet[] orderCapacitySets() {
int maxNbCapacitySets = (int) Math.pow(2, nbCriteria);
Set<BitSet> selectedCapacitySets = new HashSet<>();
for (int i = 0; i < maxNbCapacitySets; i++) {
BitSet capacitySet = intToBitSet(i, nbCriteria);
if (capacitySet.cardinality() <= kAdditivity) {
selectedCapacitySets.add(capacitySet);
}
}
return new ArrayList<>(selectedCapacitySets).stream().sorted((o1, o2) -> {
if (o1.cardinality() < o2.cardinality()) {
return -1;
}
if (o2.cardinality() < o1.cardinality()) {
return 1;
}
if (isLexBefore(o1, o2, nbCriteria)) {
return -1;
}
if (isLexBefore(o2, o1, nbCriteria)) {
return 1;
}
return 0;
}).toArray(BitSet[]::new);
}
private String bitSetToString(BitSet b, String[] measureNames) {
StringBuilder str = new StringBuilder("{");
boolean comma = false;
for (int i = 0; i < measureNames.length; i++) {
if (b.get(i)) {
if (comma) str.append(",");
str.append(measureNames[i]);
comma = true;
}
}
str.append("}");
return str.toString();
}
public void addWeightLabels(String[] measureNames) {
if (functionType.equals(ChoquetMobiusScoreFunction.TYPE)) {
BitSet[] capacity = orderCapacitySets();
if (capacity.length != weights.length) {
throw new RuntimeException("Capacity and weights must have the same length");
}
weightLabels = new String[capacity.length];
for (int i = 0; i < capacity.length; i++) {
weightLabels[i] = bitSetToString(capacity[i], measureNames);
}
return;
}
if (functionType.equals(LinearScoreFunction.TYPE)) {
weightLabels = new String[measureNames.length];
for (int i = 0; i < measureNames.length; i++) {
weightLabels[i] = "{" + measureNames[i] + "}";
}
return;
}
throw new RuntimeException("This function type is not implemented for labels : " + functionType);
}
}
package io.gitlab.chaver.minimax.cli;
import org.junit.jupiter.api.Test;
import picocli.CommandLine;
import java.io.File;
import static org.junit.jupiter.api.Assertions.*;
class LearnFunctionAndRankCliTest {
@Test
void test() throws Exception {
String rulesPath = "results/rules/iris";
String trainingPercentage = "0.26";
String resPath = File.createTempFile("rules", "").getAbsolutePath();
String measures = "phi:kruskal:yuleQ";
String seed = "1234";
String[] args = {"-d", rulesPath, "--train", trainingPercentage, "-r", resPath, "-m", measures,
"--seed", seed};
new CommandLine(new SplitTrainingTestCli()).execute(args);
String oracleName = "linear";
String learningAlgorithm = "kappalab";
String learnResPath = File.createTempFile("learn", "").getAbsolutePath();
args = new String[]{"-d", rulesPath, "-r", learnResPath, "-m", measures,
"--seed", seed, "-o", oracleName, "-l", learningAlgorithm, "--tt", resPath};
int exitCode = new CommandLine(new LearnFunctionAndRankCli()).execute(args);
assertEquals(0, exitCode);
}
}
\ No newline at end of file
package io.gitlab.chaver.minimax.cli;
import org.junit.jupiter.api.Test;
import picocli.CommandLine;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import static org.junit.jupiter.api.Assertions.*;
class SplitTrainingTestCliTest {
private long countLines(String filePath) throws IOException {
return Files.readAllLines(Paths.get(filePath)).stream().count();
}
@Test
void test() throws Exception {
String rulesPath = "results/rules/iris";
String trainingPercentage = "0.26";
String resPath = File.createTempFile("rules", "").getAbsolutePath();
String measures = "phi:kruskal:yuleQ";
String seed = "1234";
String[] args = {"-d", rulesPath, "--train", trainingPercentage, "-r", resPath, "-m", measures,
"--seed", seed};
int exitCode = new CommandLine(new SplitTrainingTestCli()).execute(args);
assertEquals(0, exitCode);
assertEquals(51, countLines(resPath + "_train.jsonl"));
assertEquals(147, countLines(resPath + "_test.jsonl"));
}
}
\ No newline at end of file
......@@ -12,13 +12,13 @@ import static io.gitlab.chaver.minimax.learn.train.LearnUtil.*;
class RanklibCallTest {
@Test
/*@Test
void test() throws Exception {
File trainingFile = new File("RankLib/examples/ex_passive.txt");
RanklibCall ranklibCall = new RanklibCall();
File resultModel = File.createTempFile("model", ".txt");
ranklibCall.trainModel(trainingFile, null, RanklibCall.Ranker.MART, "RR@10", resultModel, new ArrayList<>());
assertTrue(Files.readAllLines(Paths.get(resultModel.getAbsolutePath())).size() > 0);
}
}*/
}
\ No newline at end of file