From a46cbaabab3a9154a9af4623cc8a5dc6b08cc693 Mon Sep 17 00:00:00 2001 From: "a.croix" <croix.alexandre@gmail.com> Date: Tue, 12 Mar 2019 17:50:16 +0100 Subject: [PATCH] Some improvements --- java-roc.iml | 11 +++++ pom.xml | 22 ++++++++- src/main/java/be/cylab/java/roc/Main.java | 54 ++++++++++++++++++++++ src/main/java/be/cylab/java/roc/Point.java | 1 + src/main/java/be/cylab/java/roc/Roc.java | 3 +- 5 files changed, 89 insertions(+), 2 deletions(-) create mode 100644 src/main/java/be/cylab/java/roc/Main.java diff --git a/java-roc.iml b/java-roc.iml index cd80409..ecf7c30 100644 --- a/java-roc.iml +++ b/java-roc.iml @@ -18,5 +18,16 @@ <orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-params:5.3.1" level="project" /> <orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-engine:5.3.1" level="project" /> <orderEntry type="library" scope="TEST" name="Maven: org.junit.platform:junit-platform-engine:1.3.1" level="project" /> + <orderEntry type="library" name="Maven: com.owlike:genson:1.4" level="project" /> + <orderEntry type="library" name="Maven: info.debatty:java-aggregation:0.4" level="project" /> + <orderEntry type="module-library"> + <library name="Maven: be.cylab:java-wowa-training:0.0.4"> + <CLASSES> + <root url="jar://$MODULE_DIR$/../java-wowa-training/target/java-wowa-training-0.0.4-SNAPSHOT.jar!/" /> + </CLASSES> + <JAVADOC /> + <SOURCES /> + </library> + </orderEntry> </component> </module> \ No newline at end of file diff --git a/pom.xml b/pom.xml index ae47e60..502c8c9 100644 --- a/pom.xml +++ b/pom.xml @@ -47,8 +47,28 @@ <version>5.3.1</version> <scope>test</scope> </dependency> - </dependencies> + <!-- Only for testing the last version of java-wowa-training --> + <dependency> + <groupId>com.owlike</groupId> + <artifactId>genson</artifactId> + <version>1.4</version> + </dependency> + <dependency> + <groupId>info.debatty</groupId> + <artifactId>java-aggregation</artifactId> + <version>0.4</version> + </dependency> + + <dependency> + <groupId>be.cylab</groupId> + <artifactId>java-wowa-training</artifactId> + <version>0.0.4</version> + <scope>system</scope> + <systemPath>${project.basedir}/../java-wowa-training/target/java-wowa-training-0.0.4-SNAPSHOT.jar</systemPath> + </dependency> + </dependencies> + <!-- ________________________________________________________________________________________________________________________> <build> <plugins> <!-- leave this one first, to be sure we use this recent version --> diff --git a/src/main/java/be/cylab/java/roc/Main.java b/src/main/java/be/cylab/java/roc/Main.java new file mode 100644 index 0000000..310eb4c --- /dev/null +++ b/src/main/java/be/cylab/java/roc/Main.java @@ -0,0 +1,54 @@ +package be.cylab.java.roc; + +import be.cylab.java.wowa.training.*; +import info.debatty.java.aggregation.WOWA; + +import java.util.ArrayList; +import java.util.List; +import java.util.logging.Level; +import java.util.logging.Logger; + +public class Main { + + public static void main(String[] args) { + Logger logger = Logger.getLogger(Trainer.class.getName()); + logger.setLevel(Level.INFO); + int population_size = 100; + int crossover_rate = 60; + int mutation_rate = 10; + int selection_method = TrainerParameters.SELECTION_METHOD_RWS; + int max_generation = 110; + int initialization_method = TrainerParameters.POPULATION_INITIALIZATION_RANDOM; + + TrainerParameters parameters = new TrainerParameters(logger, population_size, + crossover_rate, mutation_rate, max_generation, selection_method, initialization_method); + + Trainer trainer = new Trainer(parameters); + + List<double[]> data = new ArrayList<double[]>(); + data.add(new double[] {0.1, 0.2, 0.3, 0.4}); + data.add(new double[] {0.1, 0.8, 0.3, 0.4}); + data.add(new double[] {0.2, 0.6, 0.3, 0.4}); + data.add(new double[] {0.1, 0.2, 0.5, 0.8}); + data.add(new double[] {0.5, 0.1, 0.2, 0.3}); + data.add(new double[] {0.1, 0.1, 0.1, 0.1}); + + + //Expected aggregated value for each data vector + double[] expected = new double[] {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; + + double[] solutions = new double[6]; + + SolutionDistance solution = trainer.run(data, expected); + WOWA wowa = new WOWA(solution.getWeightsW(), solution.getWeightsP()); + for (int i = 0; i < data.size(); i++) { + solutions[i] = wowa.aggregate(data.get(i)); + } + + boolean[] true_alert = new boolean[] {false, true, false, false, true, true}; + Roc roc = new Roc(solutions, true_alert); + System.out.println(roc.computeAUC()); + + + } +} diff --git a/src/main/java/be/cylab/java/roc/Point.java b/src/main/java/be/cylab/java/roc/Point.java index b475816..3dcd1db 100644 --- a/src/main/java/be/cylab/java/roc/Point.java +++ b/src/main/java/be/cylab/java/roc/Point.java @@ -21,6 +21,7 @@ public class Point implements Comparable<Point>{ } } + public void setScore(double score) { if (score < 0 || score > 1) { throw new IllegalArgumentException("Score must be between 0 and 1"); diff --git a/src/main/java/be/cylab/java/roc/Roc.java b/src/main/java/be/cylab/java/roc/Roc.java index 64e2eff..233421f 100644 --- a/src/main/java/be/cylab/java/roc/Roc.java +++ b/src/main/java/be/cylab/java/roc/Roc.java @@ -12,13 +12,14 @@ public class Roc { public Roc(double[] score, boolean[] true_alert) { + points = new ArrayList<>(); if (score.length != true_alert.length) { throw new IllegalStateException( "Score array and true alert array must be the same size"); } for (int i = 0; i < score.length; i++) { Point point = new Point(score[i], true_alert[i]); - this.points.add(point); + points.add(point); } Collections.sort(points); positive_examples_number = Utils.countPositiveExamples(this.points); -- GitLab