Skip to content
Snippets Groups Projects
Commit 0c812f53 authored by VERNEREY Charles's avatar VERNEREY Charles
Browse files

Add new oracles

parent d352cb3f
No related branches found
No related tags found
No related merge requests found
...@@ -66,8 +66,9 @@ L'usage de `-d`, `-m` et `--seed` est le même que précédemment (ces paramètr ...@@ -66,8 +66,9 @@ L'usage de `-d`, `-m` et `--seed` est le même que précédemment (ces paramètr
- `--tt` : correspond au chemin des fichiers d'entraînement de test du modèle (i.e. la valeur correspondante à `-r` lors de l'appel de l'API précédente) - `--tt` : correspond au chemin des fichiers d'entraînement de test du modèle (i.e. la valeur correspondante à `-r` lors de l'appel de l'API précédente)
- `-o` : nom de l'oracle utilisé, pour l'instant les valeurs suivantes sont possibles : - `-o` : nom de l'oracle utilisé, pour l'instant les valeurs suivantes sont possibles :
- **linear** : fonction linéaire (somme pondérée) - **choquetPearson** : Intégrale de Choquet calculée avec les corrélations de Pearson entre les mesures
- **owa** : Ordered Weighted Average - **owa** : Ordered Weighted Average
- **chiSquared** : Chi-Squared test statistique
- `-l` : nom de l'algorithme d'apprentissage utilisé pour apprendre le modèle, les valeurs suivantes sont possibles : - `-l` : nom de l'algorithme d'apprentissage utilisé pour apprendre le modèle, les valeurs suivantes sont possibles :
- **kappalab** - **kappalab**
- **ahp** - **ahp**
......
...@@ -3,9 +3,7 @@ package io.gitlab.chaver.minimax.cli; ...@@ -3,9 +3,7 @@ package io.gitlab.chaver.minimax.cli;
import com.google.gson.Gson; import com.google.gson.Gson;
import io.gitlab.chaver.minimax.io.Alternative; import io.gitlab.chaver.minimax.io.Alternative;
import io.gitlab.chaver.minimax.io.IAlternative; import io.gitlab.chaver.minimax.io.IAlternative;
import io.gitlab.chaver.minimax.learn.oracle.LinearFunctionOracle; import io.gitlab.chaver.minimax.learn.oracle.*;
import io.gitlab.chaver.minimax.learn.oracle.OWAOracle;
import io.gitlab.chaver.minimax.learn.oracle.ScoreFunctionOracle;
import io.gitlab.chaver.minimax.learn.train.AbstractRankingLearning; import io.gitlab.chaver.minimax.learn.train.AbstractRankingLearning;
import io.gitlab.chaver.minimax.learn.train.passive.AHPRankLearn; import io.gitlab.chaver.minimax.learn.train.passive.AHPRankLearn;
import io.gitlab.chaver.minimax.learn.train.passive.KappalabRankLearn; import io.gitlab.chaver.minimax.learn.train.passive.KappalabRankLearn;
...@@ -16,6 +14,7 @@ import io.gitlab.chaver.minimax.score.FunctionParameters; ...@@ -16,6 +14,7 @@ import io.gitlab.chaver.minimax.score.FunctionParameters;
import io.gitlab.chaver.minimax.score.IScoreFunction; import io.gitlab.chaver.minimax.score.IScoreFunction;
import io.gitlab.chaver.minimax.score.ScoreFunctionFactory; import io.gitlab.chaver.minimax.score.ScoreFunctionFactory;
import io.gitlab.chaver.minimax.util.RandomUtil; import io.gitlab.chaver.minimax.util.RandomUtil;
import io.gitlab.chaver.mining.rules.io.IRule;
import picocli.CommandLine; import picocli.CommandLine;
import picocli.CommandLine.Option; import picocli.CommandLine.Option;
...@@ -96,13 +95,26 @@ public class LearnFunctionAndRankCli implements Callable<Integer> { ...@@ -96,13 +95,26 @@ public class LearnFunctionAndRankCli implements Callable<Integer> {
return alternatives; return alternatives;
} }
private Comparator<IAlternative> getOracle(double[] weights) { private Comparator<IAlternative> getOracle(double[] weights, List<IAlternative> trainingAlternatives,
List<IAlternative> testAlternatives, List<RuleWithMeasures> trainingRules,
List<RuleWithMeasures> testRules, int nbTransactions) {
if (oracleName.equals(OWAOracle.TYPE)) { if (oracleName.equals(OWAOracle.TYPE)) {
return new OWAOracle(weights); return new OWAOracle(weights);
} }
if (oracleName.equals(LinearFunctionOracle.TYPE)) { if (oracleName.equals(LinearFunctionOracle.TYPE)) {
return new LinearFunctionOracle(weights); return new LinearFunctionOracle(weights);
} }
if (oracleName.equals("choquetPearson")) {
return new ChoquetOracle(
new CorrelationChoquetFuncBuilder(weights, trainingAlternatives.toArray(new IAlternative[0])).getCapacity()
);
}
if (oracleName.equals("chiSquared")) {
ChiSquaredOracle2 oracle = new ChiSquaredOracle2(nbTransactions);
oracle.addAlternativesRules(trainingAlternatives, trainingRules.stream().map(i -> (IRule) i).collect(Collectors.toList()));
oracle.addAlternativesRules(testAlternatives, testRules.stream().map(i -> (IRule) i).collect(Collectors.toList()));
return oracle;
}
throw new RuntimeException("Wrong oracle type: " + oracleName); throw new RuntimeException("Wrong oracle type: " + oracleName);
} }
...@@ -181,7 +193,7 @@ public class LearnFunctionAndRankCli implements Callable<Integer> { ...@@ -181,7 +193,7 @@ public class LearnFunctionAndRankCli implements Callable<Integer> {
List<IAlternative> testAlternatives = getAlternatives(testRules); List<IAlternative> testAlternatives = getAlternatives(testRules);
RandomUtil.getInstance().setSeed(seed); RandomUtil.getInstance().setSeed(seed);
double[] randomWeights = RandomUtil.getInstance().generateRandomWeights(nbMeasures); double[] randomWeights = RandomUtil.getInstance().generateRandomWeights(nbMeasures);
Comparator<IAlternative> oracle = getOracle(randomWeights); Comparator<IAlternative> oracle = getOracle(randomWeights, trainingAlternatives, testAlternatives, trainingRules, testRules, nbTransactions);
Ranking<IAlternative> expectedRanking = computeRankingWithOracle(oracle, trainingAlternatives); Ranking<IAlternative> expectedRanking = computeRankingWithOracle(oracle, trainingAlternatives);
AbstractRankingLearning algo = getLearningAlgo(expectedRanking); AbstractRankingLearning algo = getLearningAlgo(expectedRanking);
FunctionParameters functionParameters = algo.learn(); FunctionParameters functionParameters = algo.learn();
......
...@@ -22,6 +22,16 @@ public class ChiSquaredOracle2 extends ScoreOracle { ...@@ -22,6 +22,16 @@ public class ChiSquaredOracle2 extends ScoreOracle {
func = new ChiSquaredScoreFunction(nbTransactions); func = new ChiSquaredScoreFunction(nbTransactions);
} }
public ChiSquaredOracle2(int nbTransactions) {
func = new ChiSquaredScoreFunction(nbTransactions);
}
public void addAlternativesRules(List<IAlternative> alternatives, List<IRule> rules) {
for (int i = 0; i < alternatives.size(); i++) {
mapAlternativeToRule.put(alternatives.get(i), rules.get(i));
}
}
@Override @Override
public double computeScore(IAlternative a) { public double computeScore(IAlternative a) {
return func.computeScore(mapAlternativeToRule.get(a)); return func.computeScore(mapAlternativeToRule.get(a));
......
...@@ -9,16 +9,22 @@ import java.util.HashMap; ...@@ -9,16 +9,22 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
@Getter @Getter
public class RuleWithMeasures { public class RuleWithMeasures implements IRule {
private int[] x; private int[] x;
private int[] y; private int[] y;
private Map<String, Double> measureValues; private Map<String, Double> measureValues;
private @Setter Double score; private @Setter Double score;
private int freqX;
private int freqZ;
private int freqY;
public RuleWithMeasures(IRule r, IAlternative a, String[] measures) { public RuleWithMeasures(IRule r, IAlternative a, String[] measures) {
x = r.getX(); x = r.getX();
y = r.getY(); y = r.getY();
freqX = r.getFreqX();
freqY = r.getFreqY();
freqZ = r.getFreqZ();
measureValues = new HashMap<>(); measureValues = new HashMap<>();
for (int i = 0; i < measures.length; i++) { for (int i = 0; i < measures.length; i++) {
measureValues.put(measures[i], a.getVector()[i]); measureValues.put(measures[i], a.getVector()[i]);
...@@ -28,6 +34,9 @@ public class RuleWithMeasures { ...@@ -28,6 +34,9 @@ public class RuleWithMeasures {
public RuleWithMeasures(RuleWithMeasures copy, boolean withScore) { public RuleWithMeasures(RuleWithMeasures copy, boolean withScore) {
x = copy.getX(); x = copy.getX();
y = copy.getY(); y = copy.getY();
freqX = copy.getFreqX();
freqY = copy.getFreqY();
freqZ = copy.getFreqZ();
measureValues = copy.getMeasureValues(); measureValues = copy.getMeasureValues();
score = withScore ? copy.score : null; score = withScore ? copy.score : null;
} }
......
...@@ -19,7 +19,45 @@ class LearnFunctionAndRankCliTest { ...@@ -19,7 +19,45 @@ class LearnFunctionAndRankCliTest {
String[] args = {"-d", rulesPath, "--train", trainingPercentage, "-r", resPath, "-m", measures, String[] args = {"-d", rulesPath, "--train", trainingPercentage, "-r", resPath, "-m", measures,
"--seed", seed}; "--seed", seed};
new CommandLine(new SplitTrainingTestCli()).execute(args); new CommandLine(new SplitTrainingTestCli()).execute(args);
String oracleName = "linear"; String oracleName = "owa";
String learningAlgorithm = "kappalab";
String learnResPath = File.createTempFile("learn", "").getAbsolutePath();
args = new String[]{"-d", rulesPath, "-r", learnResPath, "-m", measures,
"--seed", seed, "-o", oracleName, "-l", learningAlgorithm, "--tt", resPath};
int exitCode = new CommandLine(new LearnFunctionAndRankCli()).execute(args);
assertEquals(0, exitCode);
}
@Test
void test2() throws Exception {
String rulesPath = "results/rules/iris";
String trainingPercentage = "0.26";
String resPath = File.createTempFile("rules", "").getAbsolutePath();
String measures = "phi:kruskal:yuleQ";
String seed = "1234";
String[] args = {"-d", rulesPath, "--train", trainingPercentage, "-r", resPath, "-m", measures,
"--seed", seed};
new CommandLine(new SplitTrainingTestCli()).execute(args);
String oracleName = "chiSquared";
String learningAlgorithm = "kappalab";
String learnResPath = File.createTempFile("learn", "").getAbsolutePath();
args = new String[]{"-d", rulesPath, "-r", learnResPath, "-m", measures,
"--seed", seed, "-o", oracleName, "-l", learningAlgorithm, "--tt", resPath};
int exitCode = new CommandLine(new LearnFunctionAndRankCli()).execute(args);
assertEquals(0, exitCode);
}
@Test
void test3() throws Exception {
String rulesPath = "results/rules/iris";
String trainingPercentage = "0.26";
String resPath = File.createTempFile("rules", "").getAbsolutePath();
String measures = "phi:kruskal:yuleQ";
String seed = "1234";
String[] args = {"-d", rulesPath, "--train", trainingPercentage, "-r", resPath, "-m", measures,
"--seed", seed};
new CommandLine(new SplitTrainingTestCli()).execute(args);
String oracleName = "choquetPearson";
String learningAlgorithm = "kappalab"; String learningAlgorithm = "kappalab";
String learnResPath = File.createTempFile("learn", "").getAbsolutePath(); String learnResPath = File.createTempFile("learn", "").getAbsolutePath();
args = new String[]{"-d", rulesPath, "-r", learnResPath, "-m", measures, args = new String[]{"-d", rulesPath, "-r", learnResPath, "-m", measures,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment