Skip to content
Snippets Groups Projects
Commit 9a33c252 authored by a.croix's avatar a.croix
Browse files

Implementation of Point class + Roc class + Utils class + RocCoordinates class

parent 29c1a1e5
No related branches found
No related tags found
No related merge requests found
<?xml version="1.0" encoding="UTF-8"?>
<module type="WEB_MODULE" version="4">
<component name="NewModuleRootManager" inherit-compiler-output="true">
<exclude-output />
<content url="file://$MODULE_DIR$" />
<module org.jetbrains.idea.maven.project.MavenProjectsManager.isMavenModule="true" version="4">
<component name="NewModuleRootManager" LANGUAGE_LEVEL="JDK_1_8">
<output url="file://$MODULE_DIR$/target/classes" />
<output-test url="file://$MODULE_DIR$/target/test-classes" />
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/src/main/java" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/src/main/resources" type="java-resource" />
<sourceFolder url="file://$MODULE_DIR$/src/test/java" isTestSource="true" />
<excludeFolder url="file://$MODULE_DIR$/target" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
<orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-api:5.3.1" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.apiguardian:apiguardian-api:1.0.0" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.opentest4j:opentest4j:1.1.1" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.junit.platform:junit-platform-commons:1.3.1" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-params:5.3.1" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-engine:5.3.1" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.junit.platform:junit-platform-engine:1.3.1" level="project" />
</component>
</module>
\ No newline at end of file
pom.xml 0 → 100644
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>be.cylab</groupId>
<artifactId>java-roc</artifactId>
<version>1.0-SNAPSHOT</version>
<name>${project.artifactId}</name>
<url>https://gitlab.cylab.be/cylab/java-roc</url>
<description>Java library to compute ROC curve and AUC</description>
<licenses>
<license>
<name>MIT License</name>
<url>http://www.opensource.org/licenses/mit-license.php</url>
</license>
</licenses>
<developers>
<developer>
<name>Alexandre Croix</name>
<email>alexandre.croix@rma.ac.be</email>
<organization>cylab.be</organization>
<organizationUrl>https://cylab.be</organizationUrl>
</developer>
</developers>
<dependencies>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<version>5.3.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<version>5.3.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<version>5.3.1</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<!-- leave this one first, to be sure we use this recent version -->
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.22.1</version>
</plugin>
<!-- to deploy to sonatype maven repository -->
<plugin>
<groupId>org.sonatype.plugins</groupId>
<artifactId>nexus-staging-maven-plugin</artifactId>
<version>1.6.8</version>
<extensions>true</extensions>
<configuration>
<serverId>ossrh</serverId>
<nexusUrl>https://oss.sonatype.org/</nexusUrl>
<autoReleaseAfterClose>true</autoReleaseAfterClose>
</configuration>
</plugin>
<!-- to create a jar containing the sources -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<version>3.0.1</version>
<executions>
<execution>
<id>attach-sources</id>
<goals>
<goal>jar-no-fork</goal>
</goals>
</execution>
</executions>
</plugin>
<!-- to create a jar containing the javadoc -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>2.9.1</version>
<executions>
<execution>
<phase>package</phase>
<id>attach-javadocs</id>
<goals>
<goal>jar</goal>
</goals>
<configuration>
<additionalparam>${javadoc.opts}</additionalparam>
</configuration>
</execution>
</executions>
</plugin>
<!-- To tag the release in git repository -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-release-plugin</artifactId>
<version>2.5.3</version>
<configuration>
<tagNameFormat>v@{project.version}</tagNameFormat>
</configuration>
</plugin>
<!-- to sign the jars with a GPG key -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
<version>1.6</version>
<executions>
<execution>
<id>sign-artifacts</id>
<phase>verify</phase>
<goals>
<goal>sign</goal>
</goals>
</execution>
</executions>
</plugin>
<!-- To check that there are no unused dependencies -->
<plugin>
<artifactId>maven-dependency-plugin</artifactId>
<version>2.10</version>
<executions>
<execution>
<id>dependency-analyze</id>
<phase>test</phase>
<configuration>
<failOnWarning>true</failOnWarning>
<ignoreNonCompile>false</ignoreNonCompile>
<ignoredDependencies>
<ignoreDependency>org.junit.jupiter:junit-jupiter-engine:*</ignoreDependency>
<ignoreDependency>org.junit.jupiter:junit-jupiter-params:*</ignoreDependency>
</ignoredDependencies>
</configuration>
<goals>
<goal>analyze-only</goal>
</goals>
</execution>
</executions>
</plugin>
<!-- To check that we don't use -SNAPSHOT dependencies -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-enforcer-plugin</artifactId>
<version>3.0.0-M1</version>
<executions>
<execution>
<id>enforce-no-snapshots</id>
<goals>
<goal>enforce</goal>
</goals>
<configuration>
<rules>
<requireReleaseDeps>
<message>No Snapshots Allowed!</message>
</requireReleaseDeps>
</rules>
<fail>true</fail>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-checkstyle-plugin</artifactId>
<version>2.16</version>
<executions>
<execution>
<id>validate</id>
<phase>verify</phase>
<configuration>
<configLocation>checkstyle.xml</configLocation>
<encoding>UTF-8</encoding>
<consoleOutput>true</consoleOutput>
<linkXRef>false</linkXRef>
</configuration>
<goals>
<goal>check</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
</properties>
</project>
\ No newline at end of file
package be.cylab.java.roc;
public class Point implements Comparable<Point>{
private double score;
private boolean true_alert;
public Point(final double score, final boolean true_alert) {
this.setScore(score);
this.setTrueAlert(true_alert);
}
@Override
public final int compareTo(final Point point) {
if (this.score > point.score) {
return -1;
} else if (this.score < point.score) {
return 1;
} else {
return 0;
}
}
public void setScore(double score) {
if (score < 0 || score > 1) {
throw new IllegalArgumentException("Score must be between 0 and 1");
}
this.score = score;
}
public void setTrueAlert(boolean true_alert) {
this.true_alert = true_alert;
}
public double getScore() {
return score;
}
public boolean isTrueAlert() {
return true_alert;
}
}
package be.cylab.java.roc;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class Roc {
private List<Point> points;
private int positive_examples_number;
private int negative_examples_number;
public Roc(double[] score, boolean[] true_alert) {
if (score.length != true_alert.length) {
throw new IllegalStateException(
"Score array and true alert array must be the same size");
}
for (int i = 0; i < score.length; i++) {
Point point = new Point(score[i], true_alert[i]);
this.points.add(point);
}
Collections.sort(points);
positive_examples_number = Utils.countPositiveExamples(this.points);
negative_examples_number = Utils.countNegativeExamples(this.points);
}
public List<RocCoordinates> computeRocPoints() {
int true_positive = 0;
int false_positive = 0;
List<RocCoordinates> roc_points = new ArrayList<>();
double previous_score = Double.NEGATIVE_INFINITY;
for (Point point : this.points) {
if (point.getScore() != previous_score) {
double true_detection = true_positive / positive_examples_number;
double false_alarm = false_positive /negative_examples_number;
roc_points.add(new RocCoordinates(true_detection, false_alarm));
previous_score = point.getScore();
}
if (point.getScore() == 1) {
true_positive++;
} else {
false_positive++;
}
}
double true_detection = true_positive / positive_examples_number;
double false_alarm = false_positive /negative_examples_number;
roc_points.add(new RocCoordinates(true_detection, false_alarm));
return roc_points;
}
public double computeAUC() {
int false_positive = 0;
int true_positive = 0;
int previous_false_positive = 0;
int previous_true_positive = 0;
double area = 0;
double previous_score = Double.NEGATIVE_INFINITY;
for (Point point : this.points) {
if (point.getScore() != previous_score) {
area += this.trapezoidArea(
false_positive,
previous_false_positive,
true_positive,
previous_true_positive);
previous_score = point.getScore();
previous_false_positive = false_positive;
previous_true_positive = true_positive;
}
if (point.isTrueAlert()) {
true_positive++;
} else {
false_positive++;
}
area += this.trapezoidArea(
false_positive,
previous_false_positive,
true_positive,
previous_true_positive);
}
return (area / (positive_examples_number * negative_examples_number));
}
private double trapezoidArea(double x1, double x2, double y1, double y2) {
double base = Math.abs(x1 - x2);
double height_average = (y1 + y2) / 2;
return (base * height_average);
}
}
package be.cylab.java.roc;
public class RocCoordinates {
private double false_alarm;
private double true_detection;
public RocCoordinates(double false_alarm, double true_detection) {
this.setFalseAlarm(false_alarm);
this.setTrueDetection(true_detection);
}
public double getFalseAlarm() {
return false_alarm;
}
public void setFalseAlarm(double false_alarm) {
if (false_alarm < 0 || false_alarm > 1) {
throw new IllegalArgumentException("" +
"False alarm must be between 0 and 1");
}
this.false_alarm = false_alarm;
}
public double getTrueDetection() {
return true_detection;
}
public void setTrueDetection(double true_detection) {
if (true_detection < 0 || true_detection > 1) {
throw new IllegalArgumentException(
"True detection value must be between 0 and 1");
}
this.true_detection = true_detection;
}
}
package be.cylab.java.roc;
import java.util.List;
public class Utils {
public static int countPositiveExamples(final List<Point> points) {
int positive_examples = 0;
for (Point point : points) {
if (point.isTrueAlert()) {
positive_examples++;
}
}
return positive_examples;
}
public static int countNegativeExamples(final List<Point> points) {
int negative_examples = 0;
for (Point point : points) {
if (!point.isTrueAlert()) {
negative_examples++;
}
}
return negative_examples;
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment