package io.gitlab.chaver.minimax.cli;

import com.google.gson.Gson;
import io.gitlab.chaver.minimax.io.Alternative;
import io.gitlab.chaver.minimax.io.IAlternative;
import io.gitlab.chaver.minimax.learn.oracle.LinearFunctionOracle;
import io.gitlab.chaver.minimax.learn.oracle.OWAOracle;
import io.gitlab.chaver.minimax.learn.oracle.ScoreFunctionOracle;
import io.gitlab.chaver.minimax.learn.train.AbstractRankingLearning;
import io.gitlab.chaver.minimax.learn.train.passive.KappalabRankLearn;
import io.gitlab.chaver.minimax.ranking.Ranking;
import io.gitlab.chaver.minimax.rules.io.RuleWithMeasures;
import io.gitlab.chaver.minimax.score.FunctionParameters;
import io.gitlab.chaver.minimax.score.IScoreFunction;
import io.gitlab.chaver.minimax.score.ScoreFunctionFactory;
import io.gitlab.chaver.minimax.util.RandomUtil;
import picocli.CommandLine;
import picocli.CommandLine.Option;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;

import static io.gitlab.chaver.minimax.learn.train.LearnUtil.*;

/**
 * Command-line entry point that learns a score function from a training set of rules ranked by
 * a randomly weighted oracle, then ranks a test set of rules with the learned function.
 *
 * <p>Option fields are populated by picocli before {@link #call()} runs, which is why they are
 * not {@code final}.
 */
public class LearnFunctionAndRankCli implements Callable<Integer> {

    @Option(names = "-d", description = "Path of the rules data", required = true)
    private String dataPath;

    @Option(names = "--tt", description = "Path of the training/test data", required = true)
    private String trainingTestDataPath;

    @Option(names = "-m", description = "Measures used in the function", required = true, split = ":")
    private String[] measures;

    @Option(names = "-o", description = "Name of the oracle", required = true)
    private String oracleName;

    @Option(names = "-l", description = "Learning algorithm", required = true)
    private String learnAlgorithm;

    @Option(names = "--seed", description = "Seed for random number generation")
    private long seed = 2994274L;

    @Option(names = "-r", description = "Path of the result files", required = true)
    private String resPath;

    /**
     * Reads a JSON-lines file, deserializing each non-blank line into a {@link RuleWithMeasures}.
     *
     * @param path path of the file to read
     * @return the rules in file order
     * @throws IOException if the file cannot be opened or read
     */
    private List<RuleWithMeasures> readRules(String path) throws IOException {
        // Decode explicitly as UTF-8 so the result does not depend on the platform default charset.
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(path), StandardCharsets.UTF_8))) {
            Gson gson = new Gson();
            List<RuleWithMeasures> rules = new ArrayList<>();
            String line;
            while ((line = reader.readLine()) != null) {
                // Gson returns null for an empty document; skip blank lines (e.g. a trailing
                // newline) so that no null rule ends up in the list.
                if (line.trim().isEmpty()) {
                    continue;
                }
                rules.add(gson.fromJson(line, RuleWithMeasures.class));
            }
            return rules;
        }
    }

    /**
     * Converts each rule into an {@link Alternative} whose vector holds the rule's values for the
     * measures selected with {@code -m}, in the order the measures were given.
     *
     * @param rules rules to convert
     * @return one alternative per rule, in the same order as {@code rules}
     * @throws IllegalArgumentException if a rule has no value for one of the selected measures
     */
    private List<IAlternative> getAlternatives(List<RuleWithMeasures> rules) {
        List<IAlternative> alternatives = new ArrayList<>(rules.size());
        for (RuleWithMeasures rule : rules) {
            double[] vector = new double[measures.length];
            for (int i = 0; i < measures.length; i++) {
                // NOTE(review): assumes getMeasureValues() maps measure names to Double — confirm.
                Double value = rule.getMeasureValues().get(measures[i]);
                if (value == null) {
                    // Fail with the offending measure name instead of a bare NPE on unboxing.
                    throw new IllegalArgumentException(
                            "Measure not found in rule: " + measures[i]);
                }
                vector[i] = value;
            }
            alternatives.add(new Alternative(vector));
        }
        return alternatives;
    }

    /**
     * Builds the oracle selected with {@code -o}.
     *
     * @param weights weights passed to the oracle constructor
     * @return a comparator over alternatives implementing the chosen oracle
     * @throws IllegalArgumentException if the oracle name matches no known oracle type
     */
    private Comparator<IAlternative> getOracle(double[] weights) {
        if (OWAOracle.TYPE.equals(oracleName)) {
            return new OWAOracle(weights);
        }
        if (LinearFunctionOracle.TYPE.equals(oracleName)) {
            return new LinearFunctionOracle(weights);
        }
        throw new IllegalArgumentException("Wrong oracle type: " + oracleName);
    }

    /**
     * Builds the learning algorithm selected with {@code -l}.
     *
     * @param expectedRanking ranking the algorithm is given to learn from
     * @return the learning algorithm instance
     * @throws IllegalArgumentException if the algorithm name is not supported
     */
    private AbstractRankingLearning getLearningAlgo(Ranking<IAlternative> expectedRanking) {
        if ("kappalab".equals(learnAlgorithm)) {
            return new KappalabRankLearn(expectedRanking);
        }
        throw new IllegalArgumentException("Wrong learning algorithm: " + learnAlgorithm);
    }

    /**
     * Runs the pipeline: loads training and test rules, ranks the training alternatives with a
     * randomly weighted oracle, learns a score function from that ranking and ranks the test
     * alternatives with the learned function.
     *
     * @return {@code 0} on success
     * @throws Exception if reading the input files or any step of the pipeline fails
     */
    @Override
    public Integer call() throws Exception {
        // NOTE(review): the returned transaction count is not used below. The call is kept
        // because its effects (presumably reading the properties file) are not visible here.
        getNbTransactions(dataPath + "_prop.jsonl");

        List<RuleWithMeasures> trainingRules = readRules(trainingTestDataPath + "_train.jsonl");
        List<RuleWithMeasures> testRules = readRules(trainingTestDataPath + "_test.jsonl");
        List<IAlternative> trainingAlternatives = getAlternatives(trainingRules);
        List<IAlternative> testAlternatives = getAlternatives(testRules);

        // Seed before drawing the weights so that a given --seed yields the same oracle.
        RandomUtil.getInstance().setSeed(seed);
        double[] randomWeights = RandomUtil.getInstance().generateRandomWeights(measures.length);
        Comparator<IAlternative> oracle = getOracle(randomWeights);

        Ranking<IAlternative> expectedRanking =
                computeRankingWithOracle(oracle, trainingAlternatives);
        AbstractRankingLearning algo = getLearningAlgo(expectedRanking);
        FunctionParameters functionParameters = algo.learn();
        IScoreFunction<IAlternative> func =
                ScoreFunctionFactory.getScoreFunction(functionParameters);

        Ranking<IAlternative> actualTestAlternativesRanking =
                computeRankingWithOracle(new ScoreFunctionOracle(func), testAlternatives);

        // Map the ranked indices back to the test rules they were built from.
        // NOTE(review): this list is never written anywhere and resPath (-r) is unused, so the
        // command currently produces no output files — confirm the intended result format.
        List<RuleWithMeasures> testRulesRankedWithLearnedFunc = Arrays
                .stream(actualTestAlternativesRanking.getRanking())
                .mapToObj(testRules::get)
                .collect(Collectors.toCollection(ArrayList::new));

        return 0;
    }

    /**
     * Parses the command-line arguments, runs the command and exits with its exit code.
     *
     * @param args command-line arguments
     */
    public static void main(String[] args) {
        int exitCode = new CommandLine(new LearnFunctionAndRankCli()).execute(args);
        System.exit(exitCode);
    }
}