diff --git a/java-roc.iml b/java-roc.iml index cd804096e8c082545c520cd419911f13b3dc1c45..ecf7c30362b496a8b00856163d899411e9691511 100644 --- a/java-roc.iml +++ b/java-roc.iml @@ -18,5 +18,16 @@ <orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-params:5.3.1" level="project" /> <orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-engine:5.3.1" level="project" /> <orderEntry type="library" scope="TEST" name="Maven: org.junit.platform:junit-platform-engine:1.3.1" level="project" /> + <orderEntry type="library" name="Maven: com.owlike:genson:1.4" level="project" /> + <orderEntry type="library" name="Maven: info.debatty:java-aggregation:0.4" level="project" /> + <orderEntry type="module-library"> + <library name="Maven: be.cylab:java-wowa-training:0.0.4"> + <CLASSES> + <root url="jar://$MODULE_DIR$/../java-wowa-training/target/java-wowa-training-0.0.4-SNAPSHOT.jar!/" /> + </CLASSES> + <JAVADOC /> + <SOURCES /> + </library> + </orderEntry> </component> </module> \ No newline at end of file diff --git a/pom.xml b/pom.xml index ae47e6002cbc5ba4e5e12384c66a1db98725121f..502c8c930ca3d1127a8f2a40e00e04f641d67a39 100644 --- a/pom.xml +++ b/pom.xml @@ -47,8 +47,28 @@ <version>5.3.1</version> <scope>test</scope> </dependency> - </dependencies> + <!-- Only for testing the last version of java-wowa-training --> + <dependency> + <groupId>com.owlike</groupId> + <artifactId>genson</artifactId> + <version>1.4</version> + </dependency> + <dependency> + <groupId>info.debatty</groupId> + <artifactId>java-aggregation</artifactId> + <version>0.4</version> + </dependency> + + <dependency> + <groupId>be.cylab</groupId> + <artifactId>java-wowa-training</artifactId> + <version>0.0.4</version> + <scope>system</scope> + <systemPath>${project.basedir}/../java-wowa-training/target/java-wowa-training-0.0.4-SNAPSHOT.jar</systemPath> + </dependency> + </dependencies> + <!-- ________________________________________________________________________________________________________________________> <build> <plugins> <!-- leave this one first, to be sure we use this recent version --> diff --git a/src/main/java/be/cylab/java/roc/Main.java b/src/main/java/be/cylab/java/roc/Main.java new file mode 100644 index 0000000000000000000000000000000000000000..310eb4cb1692d3322aa1be3b54c8ca411183528e --- /dev/null +++ b/src/main/java/be/cylab/java/roc/Main.java @@ -0,0 +1,54 @@ +package be.cylab.java.roc; + +import be.cylab.java.wowa.training.*; +import info.debatty.java.aggregation.WOWA; + +import java.util.ArrayList; +import java.util.List; +import java.util.logging.Level; +import java.util.logging.Logger; + +public class Main { + + public static void main(String[] args) { + Logger logger = Logger.getLogger(Trainer.class.getName()); + logger.setLevel(Level.INFO); + int population_size = 100; + int crossover_rate = 60; + int mutation_rate = 10; + int selection_method = TrainerParameters.SELECTION_METHOD_RWS; + int max_generation = 110; + int initialization_method = TrainerParameters.POPULATION_INITIALIZATION_RANDOM; + + TrainerParameters parameters = new TrainerParameters(logger, population_size, + crossover_rate, mutation_rate, max_generation, selection_method, initialization_method); + + Trainer trainer = new Trainer(parameters); + + List<double[]> data = new ArrayList<double[]>(); + data.add(new double[] {0.1, 0.2, 0.3, 0.4}); + data.add(new double[] {0.1, 0.8, 0.3, 0.4}); + data.add(new double[] {0.2, 0.6, 0.3, 0.4}); + data.add(new double[] {0.1, 0.2, 0.5, 0.8}); + data.add(new double[] {0.5, 0.1, 0.2, 0.3}); + data.add(new double[] {0.1, 0.1, 0.1, 0.1}); + + + //Expected aggregated value for each data vector + double[] expected = new double[] {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; + + double[] solutions = new double[6]; + + SolutionDistance solution = trainer.run(data, expected); + WOWA wowa = new WOWA(solution.getWeightsW(), solution.getWeightsP()); + for (int i = 0; i < data.size(); i++) { + solutions[i] = wowa.aggregate(data.get(i)); + } + + boolean[] true_alert = new boolean[] {false, true, false, false, true, true}; + Roc roc = new Roc(solutions, true_alert); + System.out.println(roc.computeAUC()); + + + } +} diff --git a/src/main/java/be/cylab/java/roc/Point.java b/src/main/java/be/cylab/java/roc/Point.java index b4758163b094742a85055238c4237bafebfbed77..3dcd1dbdded78ad0de5598741eb1ad746b3aafd0 100644 --- a/src/main/java/be/cylab/java/roc/Point.java +++ b/src/main/java/be/cylab/java/roc/Point.java @@ -21,6 +21,7 @@ public class Point implements Comparable<Point>{ } } + public void setScore(double score) { if (score < 0 || score > 1) { throw new IllegalArgumentException("Score must be between 0 and 1"); diff --git a/src/main/java/be/cylab/java/roc/Roc.java b/src/main/java/be/cylab/java/roc/Roc.java index 64e2eff370333d74d27f07852aa881cd679d69a1..233421f038daa1adbaf7428e57c86c9eed17b0e5 100644 --- a/src/main/java/be/cylab/java/roc/Roc.java +++ b/src/main/java/be/cylab/java/roc/Roc.java @@ -12,13 +12,14 @@ public class Roc { public Roc(double[] score, boolean[] true_alert) { + points = new ArrayList<>(); if (score.length != true_alert.length) { throw new IllegalStateException( "Score array and true alert array must be the same size"); } for (int i = 0; i < score.length; i++) { Point point = new Point(score[i], true_alert[i]); - this.points.add(point); + points.add(point); } Collections.sort(points); positive_examples_number = Utils.countPositiveExamples(this.points);