Skip to content
Snippets Groups Projects
Commit 714b9363 authored by VERNEREY Charles's avatar VERNEREY Charles
Browse files

Update CLI

parent 55a62acc
No related branches found
No related tags found
No related merge requests found
......@@ -7,8 +7,10 @@ import io.gitlab.chaver.minimax.learn.oracle.LinearFunctionOracle;
import io.gitlab.chaver.minimax.learn.oracle.OWAOracle;
import io.gitlab.chaver.minimax.learn.oracle.ScoreFunctionOracle;
import io.gitlab.chaver.minimax.learn.train.AbstractRankingLearning;
import io.gitlab.chaver.minimax.learn.train.passive.AHPRankLearn;
import io.gitlab.chaver.minimax.learn.train.passive.KappalabRankLearn;
import io.gitlab.chaver.minimax.ranking.Ranking;
import io.gitlab.chaver.minimax.learn.train.passive.SVMRankLearn;
import io.gitlab.chaver.minimax.ranking.*;
import io.gitlab.chaver.minimax.rules.io.RuleWithMeasures;
import io.gitlab.chaver.minimax.score.FunctionParameters;
import io.gitlab.chaver.minimax.score.IScoreFunction;
......@@ -17,13 +19,8 @@ import io.gitlab.chaver.minimax.util.RandomUtil;
import picocli.CommandLine;
import picocli.CommandLine.Option;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.io.*;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
......@@ -46,6 +43,19 @@ public class LearnFunctionAndRankCli implements Callable<Integer> {
@Option(names = "-r", description = "Path of the result files", required = true)
private String resPath;
// Kappalab parameters
@Option(names = "--delta", description = "Delta parameter (kappalab)")
private double delta = 1e-5;
@Option(names = "--kadd", description = "k-additivity of the model (kappalab)")
private int kAdd = 2;
@Option(names = "--sigf", description = "Number of significant figures (kappalab)")
private int sigf = 3;
// SVM parameters
@Option(names = "-c", description = "Regularisation parameter (SVM)")
private double regularisationParameter = 0.01;
private List<RuleWithMeasures> readRules(String path) throws IOException {
try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
Gson gson = new Gson();
......@@ -58,6 +68,22 @@ public class LearnFunctionAndRankCli implements Callable<Integer> {
}
}
private void writeRules(String path, List<RuleWithMeasures> rules) throws IOException {
try (BufferedWriter writer = new BufferedWriter(new FileWriter(path))) {
Gson gson = new Gson();
for (RuleWithMeasures r : rules) {
writer.write(gson.toJson(r) + "\n");
}
}
}
private void writeObject(String path, Object o) throws IOException {
try (BufferedWriter writer = new BufferedWriter(new FileWriter(path))) {
Gson gson = new Gson();
writer.write(gson.toJson(o));
}
}
private List<IAlternative> getAlternatives(List<RuleWithMeasures> rules) {
List<IAlternative> alternatives = new ArrayList<>();
for (RuleWithMeasures r : rules) {
......@@ -83,11 +109,56 @@ public class LearnFunctionAndRankCli implements Callable<Integer> {
private AbstractRankingLearning getLearningAlgo(Ranking<IAlternative> expectedRanking) {
if (learnAlgorithm.equals("kappalab")) {
KappalabRankLearn kappalab = new KappalabRankLearn(expectedRanking);
kappalab.setDelta(delta);
kappalab.setKAdditivity(kAdd);
kappalab.setSigf(sigf);
return kappalab;
}
if (learnAlgorithm.equals("ahp")) {
AHPRankLearn ahp = new AHPRankLearn(expectedRanking);
return ahp;
}
if (learnAlgorithm.equals("svm")) {
SVMRankLearn svm = new SVMRankLearn(expectedRanking);
svm.setRegularisationParameter(regularisationParameter);
return svm;
}
throw new RuntimeException("Wrong learning algorithm: " + learnAlgorithm);
}
private List<RankingMetric> getRankingMetrics(int nbRules) {
int top1 = (int) (0.01 * nbRules);
int top10 = (int) (0.1 * nbRules);
return Arrays.asList(new KendallConcordanceCoeff(), new SpearmanRankCorrelationCoefficient(),
new RecallMetric(top1), new RecallMetric(top10), new AveragePrecision(top1), new AveragePrecision(top10));
}
private Map<String, String> getRankingMetricLabels(int nbRules) {
int top1 = (int) (0.01 * nbRules);
int top10 = (int) (0.1 * nbRules);
Map<String, String> labels = new HashMap<>();
labels.put(KendallConcordanceCoeff.TYPE, KendallConcordanceCoeff.TYPE);
labels.put(SpearmanRankCorrelationCoefficient.TYPE, SpearmanRankCorrelationCoefficient.TYPE);
labels.put(RecallMetric.TYPE + "@" + top1, RecallMetric.TYPE + "@1%");
labels.put(RecallMetric.TYPE + "@" + top10, RecallMetric.TYPE + "@10%");
labels.put(AveragePrecision.TYPE + "@" + top1, AveragePrecision.TYPE + "@1%");
labels.put(AveragePrecision.TYPE + "@" + top10, AveragePrecision.TYPE + "@10%");
return labels;
}
private Map<String, Double> computeRankingMetricValues(List<RankingMetric> rankingMetrics,
Map<String, String> rankingMetricLabels,
Ranking<IAlternative> actualRanking,
Ranking<IAlternative> expectedRanking) {
Map<String, Double> rankingMetricValues = new HashMap<>();
for (RankingMetric metric : rankingMetrics) {
double value = metric.compute(expectedRanking, actualRanking);
String label = rankingMetricLabels.get(metric.getName());
rankingMetricValues.put(label, value);
}
return rankingMetricValues;
}
@Override
public Integer call() throws Exception {
int nbTransactions = getNbTransactions(dataPath + "_prop.jsonl");
......@@ -102,12 +173,26 @@ public class LearnFunctionAndRankCli implements Callable<Integer> {
Ranking<IAlternative> expectedRanking = computeRankingWithOracle(oracle, trainingAlternatives);
AbstractRankingLearning algo = getLearningAlgo(expectedRanking);
FunctionParameters functionParameters = algo.learn();
functionParameters.addWeightLabels(measures);
IScoreFunction<IAlternative> func = ScoreFunctionFactory.getScoreFunction(functionParameters);
Ranking<IAlternative> actualTestAlternativesRanking = computeRankingWithOracle(new ScoreFunctionOracle(func), testAlternatives);
Ranking<IAlternative> expectedTestAlternativesRanking = computeRankingWithOracle(oracle, testAlternatives);
List<RankingMetric> rankingMetrics = getRankingMetrics(testRules.size());
Map<String, String> rankingMetricLabels = getRankingMetricLabels(testRules.size());
Map<String, Double> rankingMetricValues = computeRankingMetricValues(rankingMetrics, rankingMetricLabels,
actualTestAlternativesRanking, expectedTestAlternativesRanking);
List<RuleWithMeasures> testRulesRankedWithLearnedFunc = Arrays
.stream(actualTestAlternativesRanking.getRanking())
.mapToObj(i -> testRules.get(i))
.collect(Collectors.toCollection(ArrayList::new));
for (int i = 0; i < testRulesRankedWithLearnedFunc.size(); i++) {
int pos = actualTestAlternativesRanking.getRanking()[i];
IAlternative a = testAlternatives.get(pos);
testRulesRankedWithLearnedFunc.get(i).setScore(func.computeScore(a));
}
writeRules(resPath + "_rules.jsonl", testRulesRankedWithLearnedFunc);
writeObject(resPath + "_func.jsonl", functionParameters);
writeObject(resPath + "_metrics.jsonl", rankingMetricValues);
return 0;
}
......
......@@ -29,7 +29,7 @@ public class RecallMetric implements RankingMetric {
topKB.add(predictedRanking.getObjects()[predictedRanking.getRanking()[i]]);
}
topKA.retainAll(topKB);
double value = (double) topKA.size() / k;
double value = (double) topKA.size() / topKB.size();
if (value < 0 || value > 1) {
throw new RuntimeException("R@k must be between 0 and 1");
}
......
......@@ -4,6 +4,14 @@ import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashSet;
import java.util.Set;
import static io.gitlab.chaver.minimax.util.BitSetUtil.intToBitSet;
import static io.gitlab.chaver.minimax.util.BitSetUtil.isLexBefore;
/**
* Class to save the parameters of a score function in a file
*/
......@@ -23,4 +31,67 @@ public class FunctionParameters {
private double[] shapleyValues;
private double[][] interactionIndices;
private double[] obj;
private String[] weightLabels;
private BitSet[] orderCapacitySets() {
int maxNbCapacitySets = (int) Math.pow(2, nbCriteria);
Set<BitSet> selectedCapacitySets = new HashSet<>();
for (int i = 0; i < maxNbCapacitySets; i++) {
BitSet capacitySet = intToBitSet(i, nbCriteria);
if (capacitySet.cardinality() <= kAdditivity) {
selectedCapacitySets.add(capacitySet);
}
}
return new ArrayList<>(selectedCapacitySets).stream().sorted((o1, o2) -> {
if (o1.cardinality() < o2.cardinality()) {
return -1;
}
if (o2.cardinality() < o1.cardinality()) {
return 1;
}
if (isLexBefore(o1, o2, nbCriteria)) {
return -1;
}
if (isLexBefore(o2, o1, nbCriteria)) {
return 1;
}
return 0;
}).toArray(BitSet[]::new);
}
private String bitSetToString(BitSet b, String[] measureNames) {
StringBuilder str = new StringBuilder("{");
boolean comma = false;
for (int i = 0; i < measureNames.length; i++) {
if (b.get(i)) {
if (comma) str.append(",");
str.append(measureNames[i]);
comma = true;
}
}
str.append("}");
return str.toString();
}
public void addWeightLabels(String[] measureNames) {
if (functionType.equals(ChoquetMobiusScoreFunction.TYPE)) {
BitSet[] capacity = orderCapacitySets();
if (capacity.length != weights.length) {
throw new RuntimeException("Capacity and weights must have the same length");
}
weightLabels = new String[capacity.length];
for (int i = 0; i < capacity.length; i++) {
weightLabels[i] = bitSetToString(capacity[i], measureNames);
}
return;
}
if (functionType.equals(LinearScoreFunction.TYPE)) {
weightLabels = new String[measureNames.length];
for (int i = 0; i < measureNames.length; i++) {
weightLabels[i] = "{" + measureNames[i] + "}";
}
return;
}
throw new RuntimeException("This function type is not implemented for labels : " + functionType);
}
}
......@@ -23,7 +23,9 @@ class LearnFunctionAndRankCliTest {
String learningAlgorithm = "kappalab";
String learnResPath = File.createTempFile("learn", "").getAbsolutePath();
args = new String[]{"-d", rulesPath, "-r", learnResPath, "-m", measures,
"--seed", seed, "-o", oracleName, "-l", learningAlgorithm};
"--seed", seed, "-o", oracleName, "-l", learningAlgorithm, "--tt", resPath};
int exitCode = new CommandLine(new LearnFunctionAndRankCli()).execute(args);
assertEquals(0, exitCode);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment