package io.gitlab.chaver.minimax.cli;

import com.google.gson.Gson;
import io.gitlab.chaver.minimax.io.Alternative;
import io.gitlab.chaver.minimax.io.IAlternative;
import io.gitlab.chaver.minimax.learn.oracle.*;
import io.gitlab.chaver.minimax.learn.train.AbstractRankingLearning;
import io.gitlab.chaver.minimax.learn.train.passive.AHPRankLearn;
import io.gitlab.chaver.minimax.learn.train.passive.KappalabRankLearn;
import io.gitlab.chaver.minimax.learn.train.passive.SVMRankLearn;
import io.gitlab.chaver.minimax.ranking.*;
import io.gitlab.chaver.minimax.rules.io.RuleWithMeasures;
import io.gitlab.chaver.minimax.score.FunctionParameters;
import io.gitlab.chaver.minimax.score.IScoreFunction;
import io.gitlab.chaver.minimax.score.ScoreFunctionFactory;
import io.gitlab.chaver.minimax.util.RandomUtil;
import io.gitlab.chaver.mining.rules.io.IRule;
import picocli.CommandLine;
import picocli.CommandLine.Option;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;

import static io.gitlab.chaver.minimax.learn.train.LearnUtil.*;

/**
 * Command-line entry point that learns a score function from a training ranking and
 * evaluates it on test data.
 *
 * <p>The command reads training and test rules (one JSON object per line), ranks the
 * training alternatives with the selected oracle, learns function parameters with the
 * selected learning algorithm, ranks the test alternatives with both the learned function
 * and the oracle, and writes the ordered rules, the learned parameters and the ranking
 * metric values to files prefixed with the {@code -r} option.
 */
public class LearnFunctionAndRankCli implements Callable<Integer> {

    /** Single JSON (de)serialiser reused by every read/write helper of this command. */
    private static final Gson GSON = new Gson();

    @Option(names = "-d", description = "Path of the rules data", required = true)
    private String dataPath;

    @Option(names = "--tt", description = "Path of the training/test data", required = true)
    private String trainingTestDataPath;

    @Option(names = "-m", description = "Measures used in the function", required = true, split = ":")
    private String[] measures;

    @Option(names = "-o", description = "Name of the oracle", required = true)
    private String oracleName;

    @Option(names = "-l", description = "Learning algorithm", required = true)
    private String learnAlgorithm;

    @Option(names = "--seed", description = "Seed for random number generation")
    private long seed = 2994274L;

    @Option(names = "-r", description = "Path of the result files", required = true)
    private String resPath;

    // Kappalab parameters
    @Option(names = "--delta", description = "Delta parameter (kappalab, default value : 1e-5)")
    private double delta = 1e-5;

    @Option(names = "--kadd", description = "k-additivity of the model (kappalab, default value : 2)")
    private int kAdd = 2;

    @Option(names = "--sigf", description = "Number of significant figures (kappalab, default value : 3)")
    private int sigf = 3;

    // SVM parameters
    @Option(names = "-c", description = "Regularisation parameter (svm, default value : 0.01)")
    private double regularisationParameter = 0.01;

    /**
     * Opens a buffered UTF-8 writer on {@code path}, truncating any existing file.
     *
     * @param path path of the file to write
     * @return a writer the caller must close
     * @throws IOException if the file cannot be opened for writing
     */
    private static BufferedWriter newUtf8Writer(String path) throws IOException {
        // Explicit charset: the platform default is not guaranteed to be UTF-8.
        return new BufferedWriter(
                new OutputStreamWriter(new FileOutputStream(path), StandardCharsets.UTF_8));
    }

    /**
     * Reads rules from a JSON-lines file (one rule per line, UTF-8).
     *
     * @param path path of the file to read
     * @return the rules in file order
     * @throws IOException if the file cannot be read
     */
    private List<RuleWithMeasures> readRules(String path) throws IOException {
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(path), StandardCharsets.UTF_8))) {
            List<RuleWithMeasures> rules = new ArrayList<>();
            String line;
            while ((line = reader.readLine()) != null) {
                // Gson maps an empty document to null; skip blank lines (e.g. a trailing
                // newline) so that no null rule ends up in the list.
                if (line.trim().isEmpty()) {
                    continue;
                }
                rules.add(GSON.fromJson(line, RuleWithMeasures.class));
            }
            return rules;
        }
    }

    /**
     * Writes rules to a file as JSON lines (one rule per line, UTF-8).
     *
     * @param path  path of the file to write
     * @param rules rules to serialise, written in list order
     * @throws IOException if the file cannot be written
     */
    private void writeRules(String path, List<RuleWithMeasures> rules) throws IOException {
        try (BufferedWriter writer = newUtf8Writer(path)) {
            for (RuleWithMeasures rule : rules) {
                writer.write(GSON.toJson(rule));
                writer.write("\n");
            }
        }
    }

    /**
     * Writes a single object to a file as one JSON document (UTF-8, no trailing newline).
     *
     * @param path path of the file to write
     * @param o    object to serialise
     * @throws IOException if the file cannot be written
     */
    private void writeObject(String path, Object o) throws IOException {
        try (BufferedWriter writer = newUtf8Writer(path)) {
            writer.write(GSON.toJson(o));
        }
    }

    /**
     * Converts each rule into an alternative whose vector holds the values of the selected
     * measures, in the order given by the {@code -m} option.
     *
     * @param rules rules to convert
     * @return one alternative per rule, in the same order
     * @throws IllegalArgumentException if a rule has no value for one of the selected measures
     */
    private List<IAlternative> getAlternatives(List<RuleWithMeasures> rules) {
        List<IAlternative> alternatives = new ArrayList<>(rules.size());
        for (RuleWithMeasures rule : rules) {
            double[] vector = new double[measures.length];
            for (int i = 0; i < measures.length; i++) {
                // NOTE(review): assumes getMeasureValues() maps measure names to Double.
                Double value = rule.getMeasureValues().get(measures[i]);
                if (value == null) {
                    // Fail with the offending measure name instead of a bare unboxing NPE.
                    throw new IllegalArgumentException(
                            "Measure '" + measures[i] + "' is missing from a rule of the input data");
                }
                vector[i] = value;
            }
            alternatives.add(new Alternative(vector));
        }
        return alternatives;
    }

    /** Views a list of rules with measures as a list of {@link IRule}. */
    private static List<IRule> asRules(List<RuleWithMeasures> rules) {
        return rules.stream().map(r -> (IRule) r).collect(Collectors.toList());
    }

    /**
     * Builds the oracle selected with the {@code -o} option.
     *
     * @param weights              weights given to the weight-based oracles
     * @param trainingAlternatives training alternatives (used by the choquetPearson and chiSquared oracles)
     * @param testAlternatives     test alternatives (used by the chiSquared oracle)
     * @param trainingRules        rules matching {@code trainingAlternatives} (chiSquared oracle)
     * @param testRules            rules matching {@code testAlternatives} (chiSquared oracle)
     * @param nbTransactions       number of transactions passed to the chiSquared oracle
     * @return a comparator over alternatives implementing the requested oracle
     * @throws IllegalArgumentException if the oracle name is not recognised
     */
    private Comparator<IAlternative> getOracle(double[] weights, List<IAlternative> trainingAlternatives,
                                               List<IAlternative> testAlternatives,
                                               List<RuleWithMeasures> trainingRules,
                                               List<RuleWithMeasures> testRules, int nbTransactions) {
        if (oracleName.equals(OWAOracle.TYPE)) {
            return new OWAOracle(weights);
        }
        if (oracleName.equals(LinearFunctionOracle.TYPE)) {
            return new LinearFunctionOracle(weights);
        }
        if (oracleName.equals("choquetPearson")) {
            IAlternative[] training = trainingAlternatives.toArray(new IAlternative[0]);
            return new ChoquetOracle(new CorrelationChoquetFuncBuilder(weights, training).getCapacity());
        }
        if (oracleName.equals("chiSquared")) {
            ChiSquaredOracle2 oracle = new ChiSquaredOracle2(nbTransactions);
            oracle.addAlternativesRules(trainingAlternatives, asRules(trainingRules));
            oracle.addAlternativesRules(testAlternatives, asRules(testRules));
            return oracle;
        }
        throw new IllegalArgumentException("Wrong oracle type: " + oracleName);
    }

    /**
     * Builds the learning algorithm selected with the {@code -l} option and applies the
     * algorithm-specific command-line parameters to it.
     *
     * @param expectedRanking ranking the algorithm learns from
     * @return the configured learning algorithm
     * @throws IllegalArgumentException if the algorithm name is not recognised
     */
    private AbstractRankingLearning getLearningAlgo(Ranking<IAlternative> expectedRanking) {
        if (learnAlgorithm.equals("kappalab")) {
            KappalabRankLearn kappalab = new KappalabRankLearn(expectedRanking);
            kappalab.setDelta(delta);
            kappalab.setKAdditivity(kAdd);
            kappalab.setSigf(sigf);
            return kappalab;
        }
        if (learnAlgorithm.equals("ahp")) {
            return new AHPRankLearn(expectedRanking);
        }
        if (learnAlgorithm.equals("svm")) {
            SVMRankLearn svm = new SVMRankLearn(expectedRanking);
            svm.setRegularisationParameter(regularisationParameter);
            return svm;
        }
        throw new IllegalArgumentException("Wrong learning algorithm: " + learnAlgorithm);
    }

    /**
     * Returns the number of rules that make up the given fraction of {@code nbRules},
     * truncated towards zero and never less than one.
     */
    private static int topCount(double fraction, int nbRules) {
        return Math.max(1, (int) (fraction * nbRules));
    }

    /**
     * Builds the ranking metrics to evaluate, keyed by the label under which each value is
     * reported.
     *
     * <p>Each label is bound directly to its metric instance rather than looked up through a
     * name containing the absolute cut-off: on small test sets the top-1% and top-10%
     * cut-offs are both clamped to 1, and such names would then collide and drop the
     * {@code @1%} entries from the report.
     *
     * @param nbRules number of test rules, used to derive the top-1% and top-10% cut-offs
     * @return the metrics in evaluation order, keyed by report label
     */
    private Map<String, RankingMetric> getLabelledRankingMetrics(int nbRules) {
        int top1 = topCount(0.01, nbRules);
        int top10 = topCount(0.1, nbRules);
        Map<String, RankingMetric> metrics = new LinkedHashMap<>();
        metrics.put(KendallConcordanceCoeff.TYPE, new KendallConcordanceCoeff());
        metrics.put(SpearmanRankCorrelationCoefficient.TYPE, new SpearmanRankCorrelationCoefficient());
        metrics.put(RecallMetric.TYPE + "@1%", new RecallMetric(top1));
        metrics.put(RecallMetric.TYPE + "@10%", new RecallMetric(top10));
        metrics.put(AveragePrecision.TYPE + "@1%", new AveragePrecision(top1));
        metrics.put(AveragePrecision.TYPE + "@10%", new AveragePrecision(top10));
        return metrics;
    }

    /**
     * Evaluates every metric on the given pair of rankings.
     *
     * @param labelledMetrics metrics keyed by report label
     * @param actualRanking   ranking produced by the learned function
     * @param expectedRanking ranking produced by the oracle
     * @return metric values keyed by report label
     */
    private Map<String, Double> computeRankingMetricValues(Map<String, RankingMetric> labelledMetrics,
                                                           Ranking<IAlternative> actualRanking,
                                                           Ranking<IAlternative> expectedRanking) {
        Map<String, Double> values = new HashMap<>();
        for (Map.Entry<String, RankingMetric> entry : labelledMetrics.entrySet()) {
            values.put(entry.getKey(), entry.getValue().compute(expectedRanking, actualRanking));
        }
        return values;
    }

    /**
     * Runs the whole pipeline: load data, build the oracle, learn the function, rank the
     * test rules and write the result files.
     *
     * @return 0 on success
     * @throws Exception if reading the inputs, learning, or writing the results fails
     */
    @Override
    public Integer call() throws Exception {
        int nbTransactions = getNbTransactions(dataPath + "_prop.jsonl");
        int nbMeasures = measures.length;
        List<RuleWithMeasures> trainingRules = readRules(trainingTestDataPath + "_train.jsonl");
        List<RuleWithMeasures> testRules = readRules(trainingTestDataPath + "_test.jsonl");
        List<IAlternative> trainingAlternatives = getAlternatives(trainingRules);
        List<IAlternative> testAlternatives = getAlternatives(testRules);

        // Seed before drawing the oracle weights so that runs are reproducible.
        RandomUtil.getInstance().setSeed(seed);
        double[] randomWeights = RandomUtil.getInstance().generateRandomWeights(nbMeasures);
        Comparator<IAlternative> oracle = getOracle(randomWeights, trainingAlternatives, testAlternatives,
                trainingRules, testRules, nbTransactions);

        // Learn the function from the oracle's ranking of the training alternatives.
        Ranking<IAlternative> expectedRanking = computeRankingWithOracle(oracle, trainingAlternatives);
        AbstractRankingLearning algo = getLearningAlgo(expectedRanking);
        FunctionParameters functionParameters = algo.learn();
        functionParameters.addWeightLabels(measures);
        functionParameters.setMeasureNames(measures);
        IScoreFunction<IAlternative> func = ScoreFunctionFactory.getScoreFunction(functionParameters);

        // Rank the test alternatives with the learned function and with the oracle.
        Ranking<IAlternative> actualTestAlternativesRanking =
                computeRankingWithOracle(new ScoreFunctionOracle(func), testAlternatives);
        Ranking<IAlternative> expectedTestAlternativesRanking =
                computeRankingWithOracle(oracle, testAlternatives);
        Map<String, Double> rankingMetricValues = computeRankingMetricValues(
                getLabelledRankingMetrics(testRules.size()),
                actualTestAlternativesRanking, expectedTestAlternativesRanking);

        // Test rules in learned-function order, each annotated with its learned score.
        // The rule objects of testRules themselves are updated (no copy is made here).
        int[] learnedOrder = actualTestAlternativesRanking.getRanking();
        List<RuleWithMeasures> testRulesRankedWithLearnedFunc = new ArrayList<>(learnedOrder.length);
        for (int pos : learnedOrder) {
            RuleWithMeasures rule = testRules.get(pos);
            rule.setScore(func.computeScore(testAlternatives.get(pos)));
            testRulesRankedWithLearnedFunc.add(rule);
        }

        // Test rules in oracle order. Built after the scores above have been set, as before.
        // NOTE(review): the meaning of the 'false' constructor flag is not visible here;
        // presumably it controls whether the score is carried over to the copy - confirm.
        List<RuleWithMeasures> testRulesRankedWithOracle = new ArrayList<>();
        for (int pos : expectedTestAlternativesRanking.getRanking()) {
            testRulesRankedWithOracle.add(new RuleWithMeasures(testRules.get(pos), false));
        }

        writeRules(resPath + "_ordered_test_rules.jsonl", testRulesRankedWithLearnedFunc);
        writeRules(resPath + "_ordered_test_rules_oracle.jsonl", testRulesRankedWithOracle);
        writeObject(resPath + "_func.jsonl", functionParameters);
        writeObject(resPath + "_metrics.jsonl", rankingMetricValues);
        return 0;
    }

    /**
     * Parses the command line, runs the command and exits with its exit code.
     *
     * @param args command-line arguments
     */
    public static void main(String[] args) {
        int exitCode = new CommandLine(new LearnFunctionAndRankCli()).execute(args);
        System.exit(exitCode);
    }
}