Newer
Older
package io.gitlab.chaver.minimax.cli;
import static io.gitlab.chaver.minimax.learn.train.LearnUtil.*;

import com.google.gson.Gson;
import io.gitlab.chaver.minimax.io.Alternative;
import io.gitlab.chaver.minimax.io.IAlternative;
import io.gitlab.chaver.minimax.learn.oracle.LinearFunctionOracle;
import io.gitlab.chaver.minimax.learn.oracle.OWAOracle;
import io.gitlab.chaver.minimax.learn.oracle.ScoreFunctionOracle;
import io.gitlab.chaver.minimax.learn.train.AbstractRankingLearning;
import io.gitlab.chaver.minimax.learn.train.passive.AHPRankLearn;
import io.gitlab.chaver.minimax.learn.train.passive.KappalabRankLearn;
import io.gitlab.chaver.minimax.learn.train.passive.SVMRankLearn;
import io.gitlab.chaver.minimax.ranking.*;
import io.gitlab.chaver.minimax.rules.io.RuleWithMeasures;
import io.gitlab.chaver.minimax.score.FunctionParameters;
import io.gitlab.chaver.minimax.score.IScoreFunction;
import io.gitlab.chaver.minimax.score.ScoreFunctionFactory;
import io.gitlab.chaver.minimax.util.RandomUtil;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import picocli.CommandLine;
import picocli.CommandLine.Option;
public class LearnFunctionAndRankCli implements Callable<Integer> {
@Option(names = "-d", description = "Path of the rules data", required = true)
private String dataPath;
@Option(names = "--tt", description = "Path of the training/test data", required = true)
private String trainingTestDataPath;
@Option(names = "-m", description = "Measures used in the function", required = true, split = ":")
private String[] measures;
@Option(names = "-o", description = "Name of the oracle", required = true)
private String oracleName;
@Option(names = "-l", description = "Learning algorithm", required = true)
private String learnAlgorithm;
@Option(names = "--seed", description = "Seed for random number generation")
private long seed = 2994274L;
@Option(names = "-r", description = "Path of the result files", required = true)
private String resPath;
@Option(names = "--delta", description = "Delta parameter (kappalab, default value : 1e-5)")
@Option(names = "--kadd", description = "k-additivity of the model (kappalab, default value : 2)")
@Option(names = "--sigf", description = "Number of significant figures (kappalab, default value : 3)")
@Option(names = "-c", description = "Regularisation parameter (svm, default value : 0.01)")
/**
 * Reads rules from a JSON Lines file, one {@link RuleWithMeasures} serialised as JSON per line.
 *
 * <p>The file is decoded as UTF-8 instead of the platform default charset, so results do not
 * depend on the machine's locale. Blank lines are skipped: Gson returns {@code null} for empty
 * input, and a {@code null} entry would otherwise end up in the returned list.
 *
 * @param path path of the {@code .jsonl} file
 * @return the rules, in file order
 * @throws IOException if the file cannot be opened, read or decoded as UTF-8
 */
private List<RuleWithMeasures> readRules(String path) throws IOException {
    Gson gson = new Gson();
    List<RuleWithMeasures> rules = new ArrayList<>();
    try (BufferedReader reader = Files.newBufferedReader(Paths.get(path), StandardCharsets.UTF_8)) {
        String line;
        while ((line = reader.readLine()) != null) {
            if (line.trim().isEmpty()) {
                continue;
            }
            rules.add(gson.fromJson(line, RuleWithMeasures.class));
        }
    }
    return rules;
}
/**
 * Writes rules as JSON Lines (one JSON object per line), overwriting any existing file.
 *
 * <p>Output is UTF-8 encoded rather than in the platform default charset. Lines end with
 * {@code '\n'} on every platform.
 *
 * @param path  destination file
 * @param rules rules to serialise, written in list order
 * @throws IOException if the file cannot be created or written
 */
private void writeRules(String path, List<RuleWithMeasures> rules) throws IOException {
    Gson gson = new Gson();
    try (BufferedWriter writer = Files.newBufferedWriter(Paths.get(path), StandardCharsets.UTF_8)) {
        for (RuleWithMeasures rule : rules) {
            writer.write(gson.toJson(rule));
            // Explicit LF rather than newLine(), which would emit the platform separator.
            writer.write('\n');
        }
    }
}
/**
 * Serialises a single object to a file as JSON (no trailing newline), overwriting any existing file.
 * Output is UTF-8 encoded rather than in the platform default charset.
 *
 * @param path destination file
 * @param o    object to serialise with Gson
 * @throws IOException if the file cannot be created or written
 */
private void writeObject(String path, Object o) throws IOException {
    Gson gson = new Gson();
    try (BufferedWriter writer = Files.newBufferedWriter(Paths.get(path), StandardCharsets.UTF_8)) {
        writer.write(gson.toJson(o));
    }
}
/**
 * Converts each rule into an {@link Alternative}. Each vector holds the rule's values for the
 * selected {@code measures}, in the order the measures were given on the command line.
 *
 * @param rules rules to convert
 * @return one alternative per rule, in the same order as {@code rules}
 * @throws IllegalArgumentException if a rule has no value for one of the selected measures
 */
private List<IAlternative> getAlternatives(List<RuleWithMeasures> rules) {
    List<IAlternative> alternatives = new ArrayList<>(rules.size());
    int ruleIndex = 0;
    for (RuleWithMeasures rule : rules) {
        double[] vector = new double[measures.length];
        for (int j = 0; j < measures.length; j++) {
            Double value = rule.getMeasureValues().get(measures[j]);
            if (value == null) {
                // Previously an anonymous NPE from unboxing; name the culprit instead.
                throw new IllegalArgumentException("Measure '" + measures[j]
                        + "' has no value for rule #" + ruleIndex);
            }
            vector[j] = value;
        }
        alternatives.add(new Alternative(vector));
        ruleIndex++;
    }
    return alternatives;
}
/**
 * Builds the oracle selected by {@code oracleName}. The oracle defines the "true" ordering the
 * learner must reproduce.
 *
 * @param weights weight vector passed to the oracle's constructor
 * @return an {@link OWAOracle} or a {@link LinearFunctionOracle}
 * @throws IllegalArgumentException if {@code oracleName} matches neither oracle type
 */
private Comparator<IAlternative> getOracle(double[] weights) {
    if (oracleName.equals(OWAOracle.TYPE)) {
        return new OWAOracle(weights);
    }
    if (oracleName.equals(LinearFunctionOracle.TYPE)) {
        return new LinearFunctionOracle(weights);
    }
    throw new IllegalArgumentException("Wrong oracle type: " + oracleName
            + " (expected " + OWAOracle.TYPE + " or " + LinearFunctionOracle.TYPE + ")");
}
private AbstractRankingLearning getLearningAlgo(Ranking<IAlternative> expectedRanking) {
if (learnAlgorithm.equals("kappalab")) {
KappalabRankLearn kappalab = new KappalabRankLearn(expectedRanking);
kappalab.setDelta(delta);
kappalab.setKAdditivity(kAdd);
kappalab.setSigf(sigf);
if (learnAlgorithm.equals("ahp")) {
AHPRankLearn ahp = new AHPRankLearn(expectedRanking);
return ahp;
}
if (learnAlgorithm.equals("svm")) {
SVMRankLearn svm = new SVMRankLearn(expectedRanking);
svm.setRegularisationParameter(regularisationParameter);
return svm;
}
throw new RuntimeException("Wrong learning algorithm: " + learnAlgorithm);
}
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
/**
 * Lists the metrics used to compare the learned ranking of the test rules with the oracle's.
 *
 * <p>Kendall and Spearman coefficients cover the whole ranking. Recall and average precision
 * are each evaluated at two cut-offs: 1% and 10% of {@code nbRules}, truncated to an int.
 * NOTE(review): with fewer than 100 rules the 1% cut-off is 0, and with fewer than 10 rules
 * both cut-offs are 0. Confirm that RecallMetric/AveragePrecision behave sensibly for k = 0.
 *
 * @param nbRules number of rules being ranked
 * @return the six metrics, in a fixed order
 */
private List<RankingMetric> getRankingMetrics(int nbRules) {
    final int onePercentCutoff = (int) (0.01 * nbRules);
    final int tenPercentCutoff = (int) (0.1 * nbRules);

    RankingMetric kendall = new KendallConcordanceCoeff();
    RankingMetric spearman = new SpearmanRankCorrelationCoefficient();
    RankingMetric recallAtOnePercent = new RecallMetric(onePercentCutoff);
    RankingMetric recallAtTenPercent = new RecallMetric(tenPercentCutoff);
    RankingMetric precisionAtOnePercent = new AveragePrecision(onePercentCutoff);
    RankingMetric precisionAtTenPercent = new AveragePrecision(tenPercentCutoff);

    return Arrays.asList(kendall, spearman,
            recallAtOnePercent, recallAtTenPercent,
            precisionAtOnePercent, precisionAtTenPercent);
}
/**
 * Maps each metric's internal name to the label used in the metrics output file.
 *
 * <p>Correlation metrics are labelled with their own type name. Top-k metrics are keyed as
 * {@code "<type>@<k>"} and labelled {@code "<type>@1%"} / {@code "<type>@10%"}, so the label
 * shows the percentage rather than the absolute cut-off. If both cut-offs round to the same k,
 * the "@10%" entry is inserted last and wins.
 *
 * @param nbRules number of rules being ranked; determines the absolute cut-offs
 * @return a mutable map from metric name to display label
 */
private Map<String, String> getRankingMetricLabels(int nbRules) {
    final int onePercentCutoff = (int) (0.01 * nbRules);
    final int tenPercentCutoff = (int) (0.1 * nbRules);
    Map<String, String> labelByName = new HashMap<>();
    for (String type : new String[] {KendallConcordanceCoeff.TYPE, SpearmanRankCorrelationCoefficient.TYPE}) {
        labelByName.put(type, type);
    }
    for (String type : new String[] {RecallMetric.TYPE, AveragePrecision.TYPE}) {
        labelByName.put(type + "@" + onePercentCutoff, type + "@1%");
        labelByName.put(type + "@" + tenPercentCutoff, type + "@10%");
    }
    return labelByName;
}
/**
 * Evaluates every metric on the pair of rankings and stores each value under the metric's
 * display label.
 *
 * <p>When a metric's name has no entry in {@code rankingMetricLabels}, the raw metric name is
 * used as the key rather than {@code null}. Two metrics that resolve to the same label collide,
 * and the later one overwrites the earlier.
 *
 * @param rankingMetrics      metrics to evaluate
 * @param rankingMetricLabels metric name to output label
 * @param actualRanking       ranking produced by the learned function
 * @param expectedRanking     reference ranking produced by the oracle
 * @return label to metric value
 */
private Map<String, Double> computeRankingMetricValues(List<RankingMetric> rankingMetrics,
                                                       Map<String, String> rankingMetricLabels,
                                                       Ranking<IAlternative> actualRanking,
                                                       Ranking<IAlternative> expectedRanking) {
    Map<String, Double> rankingMetricValues = new HashMap<>();
    for (RankingMetric metric : rankingMetrics) {
        double value = metric.compute(expectedRanking, actualRanking);
        String label = rankingMetricLabels.get(metric.getName());
        if (label == null) {
            label = String.valueOf(metric.getName());
        }
        rankingMetricValues.put(label, value);
    }
    return rankingMetricValues;
}
/**
 * Runs one experiment:
 * <ol>
 *   <li>draw a seeded random oracle;</li>
 *   <li>learn a scoring function from the oracle's ranking of the training rules;</li>
 *   <li>rank the test rules with both the learned function and the oracle;</li>
 *   <li>write the scored test rules, the learned function and the comparison metrics
 *       under {@code resPath}.</li>
 * </ol>
 *
 * @return 0 on success
 * @throws Exception if reading inputs, learning, or writing results fails
 */
@Override
public Integer call() throws Exception {
    // The count itself is not used below; the call is kept as-is.
    // NOTE(review): confirm whether reading "_prop.jsonl" is still required.
    int nbTransactions = getNbTransactions(dataPath + "_prop.jsonl");
    List<RuleWithMeasures> trainingRules = readRules(trainingTestDataPath + "_train.jsonl");
    List<RuleWithMeasures> testRules = readRules(trainingTestDataPath + "_test.jsonl");
    List<IAlternative> trainingAlternatives = getAlternatives(trainingRules);
    List<IAlternative> testAlternatives = getAlternatives(testRules);

    // Reproducible oracle: one random weight per selected measure.
    RandomUtil.getInstance().setSeed(seed);
    double[] oracleWeights = RandomUtil.getInstance().generateRandomWeights(measures.length);
    Comparator<IAlternative> oracle = getOracle(oracleWeights);

    // Learn a scoring function from the oracle's ordering of the training set.
    Ranking<IAlternative> expectedTrainingRanking = computeRankingWithOracle(oracle, trainingAlternatives);
    FunctionParameters functionParameters = getLearningAlgo(expectedTrainingRanking).learn();
    functionParameters.setMeasureNames(measures);
    IScoreFunction<IAlternative> func = ScoreFunctionFactory.getScoreFunction(functionParameters);

    // Compare learned vs oracle orderings of the test set.
    Ranking<IAlternative> learnedTestRanking =
            computeRankingWithOracle(new ScoreFunctionOracle(func), testAlternatives);
    Ranking<IAlternative> oracleTestRanking = computeRankingWithOracle(oracle, testAlternatives);
    int nbTestRules = testRules.size();
    Map<String, Double> rankingMetricValues = computeRankingMetricValues(
            getRankingMetrics(nbTestRules), getRankingMetricLabels(nbTestRules),
            learnedTestRanking, oracleTestRanking);

    // Test rules in learned-ranking order, each tagged with its learned score.
    int[] learnedOrder = learnedTestRanking.getRanking();
    List<RuleWithMeasures> orderedTestRules = new ArrayList<>(learnedOrder.length);
    for (int pos : learnedOrder) {
        RuleWithMeasures rule = testRules.get(pos);
        rule.setScore(func.computeScore(testAlternatives.get(pos)));
        orderedTestRules.add(rule);
    }

    writeRules(resPath + "_ordered_test_rules.jsonl", orderedTestRules);
    writeObject(resPath + "_func.jsonl", functionParameters);
    writeObject(resPath + "_metrics.jsonl", rankingMetricValues);
    return 0;
}
/**
 * Command-line entry point. Parses {@code args} with picocli, runs {@link #call()}, and exits
 * the JVM with the resulting status code.
 *
 * @param args command-line arguments
 */
public static void main(String[] args) {
    CommandLine commandLine = new CommandLine(new LearnFunctionAndRankCli());
    System.exit(commandLine.execute(args));
}
}