Skip to content
Snippets Groups Projects
Commit 55a62acc authored by VERNEREY Charles's avatar VERNEREY Charles
Browse files

Update Command Line Interface

parent 1e8acaa6
Branches
Tags
No related merge requests found
package io.gitlab.chaver.minimax.cli;
import com.google.gson.Gson;
import io.gitlab.chaver.minimax.io.Alternative;
import io.gitlab.chaver.minimax.io.IAlternative;
import io.gitlab.chaver.minimax.learn.oracle.LinearFunctionOracle;
import io.gitlab.chaver.minimax.learn.oracle.OWAOracle;
import io.gitlab.chaver.minimax.learn.oracle.ScoreFunctionOracle;
import io.gitlab.chaver.minimax.learn.train.AbstractRankingLearning;
import io.gitlab.chaver.minimax.learn.train.passive.KappalabRankLearn;
import io.gitlab.chaver.minimax.ranking.Ranking;
import io.gitlab.chaver.minimax.rules.io.RuleWithMeasures;
import io.gitlab.chaver.minimax.score.FunctionParameters;
import io.gitlab.chaver.minimax.score.IScoreFunction;
import io.gitlab.chaver.minimax.score.ScoreFunctionFactory;
import io.gitlab.chaver.minimax.util.RandomUtil;
import picocli.CommandLine;
import picocli.CommandLine.Option;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import static io.gitlab.chaver.minimax.learn.train.LearnUtil.*;
/**
 * Command line entry point that learns a score function from a training set of rules and
 * uses it to rank a test set of rules.
 *
 * <p>The training and test sets are read from {@code <tt>_train.jsonl} and
 * {@code <tt>_test.jsonl}, one JSON-encoded {@link RuleWithMeasures} per line. The expected
 * ranking of the training set is produced by an oracle built from randomly generated weights
 * (seeded with {@code --seed}).
 *
 * <p>NOTE(review): the ranked test rules are computed but nothing is written to the
 * {@code -r} result path yet; the option is currently only parsed.
 */
public class LearnFunctionAndRankCli implements Callable<Integer> {

    @Option(names = "-d", description = "Path of the rules data", required = true)
    private String dataPath;

    @Option(names = "--tt", description = "Path of the training/test data", required = true)
    private String trainingTestDataPath;

    @Option(names = "-m", description = "Measures used in the function", required = true, split = ":")
    private String[] measures;

    @Option(names = "-o", description = "Name of the oracle", required = true)
    private String oracleName;

    @Option(names = "-l", description = "Learning algorithm", required = true)
    private String learnAlgorithm;

    @Option(names = "--seed", description = "Seed for random number generation")
    private long seed = 2994274L;

    @Option(names = "-r", description = "Path of the result files", required = true)
    private String resPath;

    /**
     * Reads a JSON Lines file containing one {@link RuleWithMeasures} per line.
     *
     * @param path path of the JSON Lines file
     * @return the rules, in file order
     * @throws IOException if the file cannot be read
     */
    private List<RuleWithMeasures> readRules(String path) throws IOException {
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            Gson gson = new Gson();
            List<RuleWithMeasures> rules = new ArrayList<>();
            String line;
            while ((line = reader.readLine()) != null) {
                // Gson returns null for an empty document; skip blank lines so that no
                // null rule ends up in the list and fails later with an unrelated NPE.
                if (line.trim().isEmpty()) {
                    continue;
                }
                rules.add(gson.fromJson(line, RuleWithMeasures.class));
            }
            return rules;
        }
    }

    /**
     * Extracts the values of the selected measures from a rule, in the order of {@code -m}.
     *
     * @param rule rule whose measure values are read
     * @return one value per selected measure
     * @throws IllegalArgumentException if the rule has no value for a selected measure
     */
    private double[] measureVector(RuleWithMeasures rule) {
        double[] vector = new double[measures.length];
        for (int i = 0; i < measures.length; i++) {
            Double value = rule.getMeasureValues().get(measures[i]);
            if (value == null) {
                // Without this check the null would be auto-unboxed and raise a
                // NullPointerException that does not say which measure is missing.
                throw new IllegalArgumentException(
                        "Measure '" + measures[i] + "' is missing from a rule of the training/test data");
            }
            vector[i] = value;
        }
        return vector;
    }

    /**
     * Converts each rule into an alternative whose vector holds the selected measure values.
     *
     * @param rules rules to convert
     * @return one alternative per rule, in the same order
     * @throws IllegalArgumentException if a rule has no value for a selected measure
     */
    private List<IAlternative> getAlternatives(List<RuleWithMeasures> rules) {
        List<IAlternative> alternatives = new ArrayList<>();
        for (RuleWithMeasures r : rules) {
            alternatives.add(new Alternative(measureVector(r)));
        }
        return alternatives;
    }

    /**
     * Builds the oracle selected with {@code -o}.
     *
     * @param weights weights handed to the oracle
     * @return the oracle matching {@code oracleName}
     * @throws IllegalArgumentException if the oracle name is not recognised
     */
    private Comparator<IAlternative> getOracle(double[] weights) {
        if (oracleName.equals(OWAOracle.TYPE)) {
            return new OWAOracle(weights);
        }
        if (oracleName.equals(LinearFunctionOracle.TYPE)) {
            return new LinearFunctionOracle(weights);
        }
        throw new IllegalArgumentException("Wrong oracle type: " + oracleName);
    }

    /**
     * Builds the learning algorithm selected with {@code -l}.
     *
     * @param expectedRanking ranking the algorithm learns from
     * @return the learning algorithm matching {@code learnAlgorithm}
     * @throws IllegalArgumentException if the algorithm name is not recognised
     */
    private AbstractRankingLearning getLearningAlgo(Ranking<IAlternative> expectedRanking) {
        if (learnAlgorithm.equals("kappalab")) {
            return new KappalabRankLearn(expectedRanking);
        }
        throw new IllegalArgumentException("Wrong learning algorithm: " + learnAlgorithm);
    }

    /**
     * Learns a score function on the training rules and ranks the test rules with it.
     *
     * @return {@code 0} on success
     * @throws Exception if an input file cannot be read or learning fails
     */
    @Override
    public Integer call() throws Exception {
        // The value is not used yet; the call is kept so that a missing or unreadable
        // properties file is still reported here, as before.
        getNbTransactions(dataPath + "_prop.jsonl");
        int nbMeasures = measures.length;

        List<RuleWithMeasures> trainingRules = readRules(trainingTestDataPath + "_train.jsonl");
        List<RuleWithMeasures> testRules = readRules(trainingTestDataPath + "_test.jsonl");
        List<IAlternative> trainingAlternatives = getAlternatives(trainingRules);
        List<IAlternative> testAlternatives = getAlternatives(testRules);

        // Seed before drawing the weights so that runs are reproducible for a given --seed.
        RandomUtil.getInstance().setSeed(seed);
        double[] randomWeights = RandomUtil.getInstance().generateRandomWeights(nbMeasures);
        Comparator<IAlternative> oracle = getOracle(randomWeights);
        Ranking<IAlternative> expectedRanking = computeRankingWithOracle(oracle, trainingAlternatives);

        AbstractRankingLearning algo = getLearningAlgo(expectedRanking);
        FunctionParameters functionParameters = algo.learn();
        IScoreFunction<IAlternative> func = ScoreFunctionFactory.getScoreFunction(functionParameters);

        Ranking<IAlternative> actualTestAlternativesRanking =
                computeRankingWithOracle(new ScoreFunctionOracle(func), testAlternatives);
        // Map the ranked positions back to the test rules they refer to.
        // NOTE(review): this list is not written anywhere yet (resPath is unused).
        List<RuleWithMeasures> testRulesRankedWithLearnedFunc = Arrays
                .stream(actualTestAlternativesRanking.getRanking())
                .mapToObj(testRules::get)
                .collect(Collectors.toCollection(ArrayList::new));
        return 0;
    }

    /**
     * Runs the command and exits the JVM with the exit code returned by picocli.
     *
     * @param args command line arguments
     */
    public static void main(String[] args) {
        int exitCode = new CommandLine(new LearnFunctionAndRankCli()).execute(args);
        System.exit(exitCode);
    }
}
package io.gitlab.chaver.minimax.cli;
import com.google.gson.Gson;
import io.gitlab.chaver.minimax.io.Alternative;
import io.gitlab.chaver.minimax.io.IAlternative;
import io.gitlab.chaver.minimax.normalizer.INormalizer;
import io.gitlab.chaver.minimax.rules.io.RuleMeasures;
import io.gitlab.chaver.minimax.rules.io.RuleWithMeasures;
import io.gitlab.chaver.minimax.util.RandomUtil;
import io.gitlab.chaver.mining.rules.io.IRule;
import picocli.CommandLine;
import picocli.CommandLine.Option;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import static io.gitlab.chaver.minimax.learn.train.LearnUtil.*;
/**
 * Command line entry point that computes measures for a set of association rules and splits
 * the rules into a training file and a test file.
 *
 * <p>Rules are read from {@code <d>_sols.jsonl} and the number of transactions from
 * {@code <d>_prop.jsonl}. The results are written as JSON Lines to {@code <r>_train.jsonl}
 * and {@code <r>_test.jsonl}.
 */
public class SplitTrainingTestCli implements Callable<Integer> {

    @Option(names = "-d", description = "Path of the rules data", required = true)
    private String dataPath;

    @Option(names = "--train", description = "% of data used for the training (value in [0,1])", required = true)
    private double trainingPercentage;

    @Option(names = "-r", description = "Path of the result files", required = true)
    private String resPath;

    @Option(names = "-m", description = "Measures to compute for each rule", required = true, split = ":")
    private String[] measures;

    @Option(names = "--seed", description = "Seed for random number generation")
    private long seed = 2994274L;

    @Option(names = "--smooth", description = "Smooth measures")
    private double smoothCounts = 0.1d;

    /**
     * Writes each element of a list as one JSON document per line.
     *
     * @param l elements to serialise
     * @param path path of the file to create or overwrite
     * @param <T> element type
     * @throws IOException if the file cannot be written
     */
    private <T> void writeJsonLines(List<T> l, String path) throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(path))) {
            Gson gson = new Gson();
            for (T elt : l) {
                writer.write(gson.toJson(elt) + "\n");
            }
        }
    }

    /**
     * Pairs each rule with the alternative at the same index.
     *
     * @param rules rules of one split
     * @param alternatives measure vectors of the same split, index-aligned with {@code rules}
     * @return one {@link RuleWithMeasures} per rule, in the same order
     */
    private List<RuleWithMeasures> pairRulesWithMeasures(List<IRule> rules, List<IAlternative> alternatives) {
        List<RuleWithMeasures> paired = new ArrayList<>();
        for (int i = 0; i < rules.size(); i++) {
            paired.add(new RuleWithMeasures(rules.get(i), alternatives.get(i), measures));
        }
        return paired;
    }

    /**
     * Computes the measures, splits the rules and writes the training and test files.
     *
     * @return {@code 0} on success
     * @throws IllegalArgumentException if {@code --train} is not a value in [0,1]
     * @throws Exception if an input file cannot be read or an output file cannot be written
     */
    @Override
    public Integer call() throws Exception {
        // The negated form also rejects NaN, for which both comparisons are false.
        if (!(trainingPercentage >= 0d && trainingPercentage <= 1d)) {
            throw new IllegalArgumentException(
                    "Training percentage must be a value in [0,1]: " + trainingPercentage);
        }

        List<IRule> rules = readAssociationRules(dataPath + "_sols.jsonl");
        int nbTransactions = getNbTransactions(dataPath + "_prop.jsonl");
        List<IAlternative> alternatives = rules
                .stream()
                .map(r -> new Alternative(new RuleMeasures(r, nbTransactions, smoothCounts).computeMeasures(measures)))
                .collect(Collectors.toCollection(ArrayList::new));
        alternatives = INormalizer.tchebychefNormalize(alternatives);

        // Seed before drawing the folds so that the split is reproducible for a given --seed.
        RandomUtil.getInstance().setSeed(seed);
        int foldSize = (int) (trainingPercentage * rules.size());
        int[][] folds = RandomUtil.getInstance().kFolds(1, rules.size(), foldSize);

        // The same folds are applied to rules and alternatives so both stay index-aligned.
        List<IAlternative> trainingAlternatives = getTrainingData(alternatives, folds).get(0);
        List<IRule> trainingRules = getTrainingData(rules, folds).get(0);
        List<IAlternative> testAlternatives = getTestData(alternatives, folds).get(0);
        List<IRule> testRules = getTestData(rules, folds).get(0);

        writeJsonLines(pairRulesWithMeasures(trainingRules, trainingAlternatives), resPath + "_train.jsonl");
        writeJsonLines(pairRulesWithMeasures(testRules, testAlternatives), resPath + "_test.jsonl");
        return 0;
    }

    /**
     * Runs the command and exits the JVM with the exit code returned by picocli.
     *
     * @param args command line arguments
     */
    public static void main(String[] args) {
        int exitCode = new CommandLine(new SplitTrainingTestCli()).execute(args);
        System.exit(exitCode);
    }
}
package io.gitlab.chaver.minimax.rules.io;
import io.gitlab.chaver.minimax.io.IAlternative;
import io.gitlab.chaver.mining.rules.io.IRule;
import lombok.Getter;
import lombok.Setter;
import java.util.HashMap;
import java.util.Map;
/**
 * A rule together with the value of each measure computed for it and an optional score.
 *
 * <p>Getters are generated by Lombok for every field; only {@code score} has a setter.
 */
@Getter
public class RuleWithMeasures {

    /** Value of {@code getX()} of the source rule (presumably the antecedent items; confirm). */
    private int[] x;

    /** Value of {@code getY()} of the source rule (presumably the consequent items; confirm). */
    private int[] y;

    /** Measure value indexed by measure name. */
    private Map<String, Double> measureValues;

    /** Score of the rule; not set by the constructor. */
    private @Setter Double score;

    /**
     * Builds the rule from its source rule and the vector holding its measure values.
     *
     * @param r rule providing {@code x} and {@code y}
     * @param a alternative whose vector holds one value per measure
     * @param measures measure names; the name at index {@code i} is mapped to the vector
     *                 component at index {@code i}
     */
    public RuleWithMeasures(IRule r, IAlternative a, String[] measures) {
        this.x = r.getX();
        this.y = r.getY();
        this.measureValues = new HashMap<>();
        int position = 0;
        for (String measureName : measures) {
            this.measureValues.put(measureName, a.getVector()[position]);
            position++;
        }
    }
}
package io.gitlab.chaver.minimax.cli;
import org.junit.jupiter.api.Test;
import picocli.CommandLine;
import java.io.File;
import static org.junit.jupiter.api.Assertions.*;
/**
 * End-to-end test chaining {@link SplitTrainingTestCli} and {@link LearnFunctionAndRankCli}:
 * the first command produces the training/test files that the second one consumes.
 */
class LearnFunctionAndRankCliTest {

    @Test
    void test() throws Exception {
        String rulesPath = "results/rules/iris";
        String trainingPercentage = "0.26";
        String resPath = File.createTempFile("rules", "").getAbsolutePath();
        String measures = "phi:kruskal:yuleQ";
        String seed = "1234";
        String[] args = {"-d", rulesPath, "--train", trainingPercentage, "-r", resPath, "-m", measures,
                "--seed", seed};
        // The split must succeed, otherwise the learning step has no input files.
        int splitExitCode = new CommandLine(new SplitTrainingTestCli()).execute(args);
        assertEquals(0, splitExitCode);

        String oracleName = "linear";
        String learningAlgorithm = "kappalab";
        String learnResPath = File.createTempFile("learn", "").getAbsolutePath();
        // "--tt" is a required option: it points at the prefix of the files written by the split.
        args = new String[]{"-d", rulesPath, "--tt", resPath, "-r", learnResPath, "-m", measures,
                "--seed", seed, "-o", oracleName, "-l", learningAlgorithm};
        // NOTE(review): assumes the kappalab learner can run in the test environment - confirm.
        int learnExitCode = new CommandLine(new LearnFunctionAndRankCli()).execute(args);
        assertEquals(0, learnExitCode);
    }
}
\ No newline at end of file
package io.gitlab.chaver.minimax.cli;
import org.junit.jupiter.api.Test;
import picocli.CommandLine;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import static org.junit.jupiter.api.Assertions.*;
/**
 * Checks that {@link SplitTrainingTestCli} succeeds on the iris rules and writes training and
 * test files with the expected number of lines.
 */
class SplitTrainingTestCliTest {

    /**
     * Counts the lines of a text file.
     *
     * @param filePath path of the file to read
     * @return number of lines in the file
     * @throws IOException if the file cannot be read
     */
    private long countLines(String filePath) throws IOException {
        return Files.readAllLines(Paths.get(filePath)).size();
    }

    @Test
    void test() throws Exception {
        // The temporary file only provides a unique prefix for the two result files.
        String outputPrefix = File.createTempFile("rules", "").getAbsolutePath();
        String[] cliArgs = new String[]{
                "-d", "results/rules/iris",
                "--train", "0.26",
                "-r", outputPrefix,
                "-m", "phi:kruskal:yuleQ",
                "--seed", "1234"
        };

        int status = new CommandLine(new SplitTrainingTestCli()).execute(cliArgs);

        assertEquals(0, status);
        assertEquals(51, countLines(outputPrefix + "_train.jsonl"));
        assertEquals(147, countLines(outputPrefix + "_test.jsonl"));
    }
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment