Skip to content
Snippets Groups Projects
Commit a46cbaab authored by a.croix's avatar a.croix
Browse files

Some improvements

parent 9a33c252
No related branches found
No related tags found
No related merge requests found
...@@ -18,5 +18,16 @@ ...@@ -18,5 +18,16 @@
<orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-params:5.3.1" level="project" /> <orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-params:5.3.1" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-engine:5.3.1" level="project" /> <orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-engine:5.3.1" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.junit.platform:junit-platform-engine:1.3.1" level="project" /> <orderEntry type="library" scope="TEST" name="Maven: org.junit.platform:junit-platform-engine:1.3.1" level="project" />
<orderEntry type="library" name="Maven: com.owlike:genson:1.4" level="project" />
<orderEntry type="library" name="Maven: info.debatty:java-aggregation:0.4" level="project" />
<orderEntry type="module-library">
<library name="Maven: be.cylab:java-wowa-training:0.0.4">
<CLASSES>
<root url="jar://$MODULE_DIR$/../java-wowa-training/target/java-wowa-training-0.0.4-SNAPSHOT.jar!/" />
</CLASSES>
<JAVADOC />
<SOURCES />
</library>
</orderEntry>
</component> </component>
</module> </module>
\ No newline at end of file
...@@ -47,8 +47,28 @@ ...@@ -47,8 +47,28 @@
<version>5.3.1</version> <version>5.3.1</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
</dependencies> <!-- Only for testing the last version of java-wowa-training -->
<dependency>
<groupId>com.owlike</groupId>
<artifactId>genson</artifactId>
<version>1.4</version>
</dependency>
<dependency>
<groupId>info.debatty</groupId>
<artifactId>java-aggregation</artifactId>
<version>0.4</version>
</dependency>
<dependency>
<groupId>be.cylab</groupId>
<artifactId>java-wowa-training</artifactId>
<version>0.0.4</version>
<scope>system</scope>
<systemPath>${project.basedir}/../java-wowa-training/target/java-wowa-training-0.0.4-SNAPSHOT.jar</systemPath>
</dependency>
</dependencies>
<!-- ________________________________________________________________________________________________________________________>
<build> <build>
<plugins> <plugins>
<!-- leave this one first, to be sure we use this recent version --> <!-- leave this one first, to be sure we use this recent version -->
......
package be.cylab.java.roc;
import be.cylab.java.wowa.training.*;
import info.debatty.java.aggregation.WOWA;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
public class Main {
public static void main(String[] args) {
Logger logger = Logger.getLogger(Trainer.class.getName());
logger.setLevel(Level.INFO);
int population_size = 100;
int crossover_rate = 60;
int mutation_rate = 10;
int selection_method = TrainerParameters.SELECTION_METHOD_RWS;
int max_generation = 110;
int initialization_method = TrainerParameters.POPULATION_INITIALIZATION_RANDOM;
TrainerParameters parameters = new TrainerParameters(logger, population_size,
crossover_rate, mutation_rate, max_generation, selection_method, initialization_method);
Trainer trainer = new Trainer(parameters);
List<double[]> data = new ArrayList<double[]>();
data.add(new double[] {0.1, 0.2, 0.3, 0.4});
data.add(new double[] {0.1, 0.8, 0.3, 0.4});
data.add(new double[] {0.2, 0.6, 0.3, 0.4});
data.add(new double[] {0.1, 0.2, 0.5, 0.8});
data.add(new double[] {0.5, 0.1, 0.2, 0.3});
data.add(new double[] {0.1, 0.1, 0.1, 0.1});
//Expected aggregated value for each data vector
double[] expected = new double[] {0.1, 0.2, 0.3, 0.4, 0.5, 0.6};
double[] solutions = new double[6];
SolutionDistance solution = trainer.run(data, expected);
WOWA wowa = new WOWA(solution.getWeightsW(), solution.getWeightsP());
for (int i = 0; i < data.size(); i++) {
solutions[i] = wowa.aggregate(data.get(i));
}
boolean[] true_alert = new boolean[] {false, true, false, false, true, true};
Roc roc = new Roc(solutions, true_alert);
System.out.println(roc.computeAUC());
}
}
...@@ -21,6 +21,7 @@ public class Point implements Comparable<Point>{ ...@@ -21,6 +21,7 @@ public class Point implements Comparable<Point>{
} }
} }
public void setScore(double score) { public void setScore(double score) {
if (score < 0 || score > 1) { if (score < 0 || score > 1) {
throw new IllegalArgumentException("Score must be between 0 and 1"); throw new IllegalArgumentException("Score must be between 0 and 1");
......
...@@ -12,13 +12,14 @@ public class Roc { ...@@ -12,13 +12,14 @@ public class Roc {
public Roc(double[] score, boolean[] true_alert) { public Roc(double[] score, boolean[] true_alert) {
points = new ArrayList<>();
if (score.length != true_alert.length) { if (score.length != true_alert.length) {
throw new IllegalStateException( throw new IllegalStateException(
"Score array and true alert array must be the same size"); "Score array and true alert array must be the same size");
} }
for (int i = 0; i < score.length; i++) { for (int i = 0; i < score.length; i++) {
Point point = new Point(score[i], true_alert[i]); Point point = new Point(score[i], true_alert[i]);
this.points.add(point); points.add(point);
} }
Collections.sort(points); Collections.sort(points);
positive_examples_number = Utils.countPositiveExamples(this.points); positive_examples_number = Utils.countPositiveExamples(this.points);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment