From 9a33c25272200cf4a16b1bc08a89b5f87b4353db Mon Sep 17 00:00:00 2001 From: "a.croix" <croix.alexandre@gmail.com> Date: Mon, 11 Mar 2019 17:44:12 +0100 Subject: [PATCH] Implementation of Point class + Roc class + Utils class + RocCoordinates class --- java-roc.iml | 21 +- pom.xml | 210 ++++++++++++++++++ src/main/java/be/cylab/java/roc/Point.java | 42 ++++ src/main/java/be/cylab/java/roc/Roc.java | 91 ++++++++ .../be/cylab/java/roc/RocCoordinates.java | 36 +++ src/main/java/be/cylab/java/roc/Utils.java | 26 +++ 6 files changed, 422 insertions(+), 4 deletions(-) create mode 100644 pom.xml create mode 100644 src/main/java/be/cylab/java/roc/Point.java create mode 100644 src/main/java/be/cylab/java/roc/Roc.java create mode 100644 src/main/java/be/cylab/java/roc/RocCoordinates.java create mode 100644 src/main/java/be/cylab/java/roc/Utils.java diff --git a/java-roc.iml b/java-roc.iml index 8021953..cd80409 100644 --- a/java-roc.iml +++ b/java-roc.iml @@ -1,9 +1,22 @@ <?xml version="1.0" encoding="UTF-8"?> -<module type="WEB_MODULE" version="4"> - <component name="NewModuleRootManager" inherit-compiler-output="true"> - <exclude-output /> - <content url="file://$MODULE_DIR$" /> +<module org.jetbrains.idea.maven.project.MavenProjectsManager.isMavenModule="true" version="4"> + <component name="NewModuleRootManager" LANGUAGE_LEVEL="JDK_1_8"> + <output url="file://$MODULE_DIR$/target/classes" /> + <output-test url="file://$MODULE_DIR$/target/test-classes" /> + <content url="file://$MODULE_DIR$"> + <sourceFolder url="file://$MODULE_DIR$/src/main/java" isTestSource="false" /> + <sourceFolder url="file://$MODULE_DIR$/src/main/resources" type="java-resource" /> + <sourceFolder url="file://$MODULE_DIR$/src/test/java" isTestSource="true" /> + <excludeFolder url="file://$MODULE_DIR$/target" /> + </content> <orderEntry type="inheritedJdk" /> <orderEntry type="sourceFolder" forTests="false" /> + <orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-api:5.3.1" level="project" /> + <orderEntry type="library" scope="TEST" name="Maven: org.apiguardian:apiguardian-api:1.0.0" level="project" /> + <orderEntry type="library" scope="TEST" name="Maven: org.opentest4j:opentest4j:1.1.1" level="project" /> + <orderEntry type="library" scope="TEST" name="Maven: org.junit.platform:junit-platform-commons:1.3.1" level="project" /> + <orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-params:5.3.1" level="project" /> + <orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-engine:5.3.1" level="project" /> + <orderEntry type="library" scope="TEST" name="Maven: org.junit.platform:junit-platform-engine:1.3.1" level="project" /> </component> </module> \ No newline at end of file diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000..ae47e60 --- /dev/null +++ b/pom.xml @@ -0,0 +1,210 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + + <groupId>be.cylab</groupId> + <artifactId>java-roc</artifactId> + <version>1.0-SNAPSHOT</version> + + <name>${project.artifactId}</name> + <url>https://gitlab.cylab.be/cylab/java-roc</url> + <description>Java library to compute ROC curve and AUC</description> + + <licenses> + <license> + <name>MIT License</name> + <url>http://www.opensource.org/licenses/mit-license.php</url> + </license> + </licenses> + + <developers> + <developer> + <name>Alexandre Croix</name> + <email>alexandre.croix@rma.ac.be</email> + <organization>cylab.be</organization> + <organizationUrl>https://cylab.be</organizationUrl> + </developer> + </developers> + + <dependencies> + <dependency> + <groupId>org.junit.jupiter</groupId> + <artifactId>junit-jupiter-api</artifactId> + <version>5.3.1</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.junit.jupiter</groupId> + <artifactId>junit-jupiter-params</artifactId> + <version>5.3.1</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.junit.jupiter</groupId> + <artifactId>junit-jupiter-engine</artifactId> + <version>5.3.1</version> + <scope>test</scope> + </dependency> + </dependencies> + + <build> + <plugins> + <!-- leave this one first, to be sure we use this recent version --> + <plugin> + <artifactId>maven-surefire-plugin</artifactId> + <version>2.22.1</version> + </plugin> + + <!-- to deploy to sonatype maven repository --> + <plugin> + <groupId>org.sonatype.plugins</groupId> + <artifactId>nexus-staging-maven-plugin</artifactId> + <version>1.6.8</version> + <extensions>true</extensions> + <configuration> + <serverId>ossrh</serverId> + <nexusUrl>https://oss.sonatype.org/</nexusUrl> + <autoReleaseAfterClose>true</autoReleaseAfterClose> + </configuration> + </plugin> + + <!-- to create a jar containing the sources --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-source-plugin</artifactId> + <version>3.0.1</version> + <executions> + <execution> + <id>attach-sources</id> + <goals> + <goal>jar-no-fork</goal> + </goals> + </execution> + </executions> + </plugin> + + <!-- to create a jar containing the javadoc --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-javadoc-plugin</artifactId> + <version>2.9.1</version> + <executions> + <execution> + <phase>package</phase> + <id>attach-javadocs</id> + <goals> + <goal>jar</goal> + </goals> + <configuration> + <additionalparam>${javadoc.opts}</additionalparam> + </configuration> + </execution> + </executions> + </plugin> + + <!-- To tag the release in git repository --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-release-plugin</artifactId> + <version>2.5.3</version> + <configuration> + <tagNameFormat>v@{project.version}</tagNameFormat> + </configuration> + </plugin> + + <!-- to sign the jars with a GPG key --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-gpg-plugin</artifactId> + <version>1.6</version> + <executions> + <execution> + <id>sign-artifacts</id> + <phase>verify</phase> + <goals> + <goal>sign</goal> + </goals> + </execution> + </executions> + </plugin> + + <!-- To check that there are no unused dependencies --> + <plugin> + <artifactId>maven-dependency-plugin</artifactId> + <version>2.10</version> + <executions> + <execution> + <id>dependency-analyze</id> + <phase>test</phase> + <configuration> + <failOnWarning>true</failOnWarning> + <ignoreNonCompile>false</ignoreNonCompile> + <ignoredDependencies> + <ignoreDependency>org.junit.jupiter:junit-jupiter-engine:*</ignoreDependency> + <ignoreDependency>org.junit.jupiter:junit-jupiter-params:*</ignoreDependency> + </ignoredDependencies> + </configuration> + <goals> + <goal>analyze-only</goal> + </goals> + </execution> + </executions> + </plugin> + + <!-- To check that we don't use -SNAPSHOT dependencies --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-enforcer-plugin</artifactId> + <version>3.0.0-M1</version> + <executions> + <execution> + <id>enforce-no-snapshots</id> + <goals> + <goal>enforce</goal> + </goals> + <configuration> + <rules> + <requireReleaseDeps> + <message>No Snapshots Allowed!</message> + </requireReleaseDeps> + </rules> + <fail>true</fail> + </configuration> + </execution> + </executions> + </plugin> + + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-checkstyle-plugin</artifactId> + <version>2.16</version> + <executions> + <execution> + <id>validate</id> + <phase>verify</phase> + <configuration> + <configLocation>checkstyle.xml</configLocation> + <encoding>UTF-8</encoding> + <consoleOutput>true</consoleOutput> + <linkXRef>false</linkXRef> + </configuration> + <goals> + <goal>check</goal> + </goals> + </execution> + </executions> + </plugin> + </plugins> + </build> + + <properties> + <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> + <maven.compiler.source>1.8</maven.compiler.source> + <maven.compiler.target>1.8</maven.compiler.target> + </properties> + + + +</project> \ No newline at end of file diff --git a/src/main/java/be/cylab/java/roc/Point.java b/src/main/java/be/cylab/java/roc/Point.java new file mode 100644 index 0000000..b475816 --- /dev/null +++ b/src/main/java/be/cylab/java/roc/Point.java @@ -0,0 +1,42 @@ +package be.cylab.java.roc; + +public class Point implements Comparable<Point>{ + + private double score; + private boolean true_alert; + + public Point(final double score, final boolean true_alert) { + this.setScore(score); + this.setTrueAlert(true_alert); + } + + @Override + public final int compareTo(final Point point) { + if (this.score > point.score) { + return -1; + } else if (this.score < point.score) { + return 1; + } else { + return 0; + } + } + + public void setScore(double score) { + if (score < 0 || score > 1) { + throw new IllegalArgumentException("Score must be between 0 and 1"); + } + this.score = score; + } + + public void setTrueAlert(boolean true_alert) { + this.true_alert = true_alert; + } + + public double getScore() { + return score; + } + + public boolean isTrueAlert() { + return true_alert; + } +} diff --git a/src/main/java/be/cylab/java/roc/Roc.java b/src/main/java/be/cylab/java/roc/Roc.java new file mode 100644 index 0000000..64e2eff --- /dev/null +++ b/src/main/java/be/cylab/java/roc/Roc.java @@ -0,0 +1,91 @@ +package be.cylab.java.roc; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class Roc { + + private List<Point> points; + private int positive_examples_number; + private int negative_examples_number; + + + public Roc(double[] score, boolean[] true_alert) { + if (score.length != true_alert.length) { + throw new IllegalStateException( + "Score array and true alert array must be the same size"); + } + for (int i = 0; i < score.length; i++) { + Point point = new Point(score[i], true_alert[i]); + this.points.add(point); + } + Collections.sort(points); + positive_examples_number = Utils.countPositiveExamples(this.points); + negative_examples_number = Utils.countNegativeExamples(this.points); + } + + public List<RocCoordinates> computeRocPoints() { + int true_positive = 0; + int false_positive = 0; + List<RocCoordinates> roc_points = new ArrayList<>(); + double previous_score = Double.NEGATIVE_INFINITY; + for (Point point : this.points) { + if (point.getScore() != previous_score) { + double true_detection = true_positive / positive_examples_number; + double false_alarm = false_positive /negative_examples_number; + roc_points.add(new RocCoordinates(true_detection, false_alarm)); + previous_score = point.getScore(); + } + if (point.getScore() == 1) { + true_positive++; + } else { + false_positive++; + } + } + double true_detection = true_positive / positive_examples_number; + double false_alarm = false_positive /negative_examples_number; + roc_points.add(new RocCoordinates(true_detection, false_alarm)); + return roc_points; + } + + + public double computeAUC() { + int false_positive = 0; + int true_positive = 0; + int previous_false_positive = 0; + int previous_true_positive = 0; + double area = 0; + double previous_score = Double.NEGATIVE_INFINITY; + for (Point point : this.points) { + if (point.getScore() != previous_score) { + area += this.trapezoidArea( + false_positive, + previous_false_positive, + true_positive, + previous_true_positive); + previous_score = point.getScore(); + previous_false_positive = false_positive; + previous_true_positive = true_positive; + } + if (point.isTrueAlert()) { + true_positive++; + } else { + false_positive++; + } + area += this.trapezoidArea( + false_positive, + previous_false_positive, + true_positive, + previous_true_positive); + } + return (area / (positive_examples_number * negative_examples_number)); + } + + private double trapezoidArea(double x1, double x2, double y1, double y2) { + double base = Math.abs(x1 - x2); + double height_average = (y1 + y2) / 2; + return (base * height_average); + } + +} diff --git a/src/main/java/be/cylab/java/roc/RocCoordinates.java b/src/main/java/be/cylab/java/roc/RocCoordinates.java new file mode 100644 index 0000000..f037241 --- /dev/null +++ b/src/main/java/be/cylab/java/roc/RocCoordinates.java @@ -0,0 +1,36 @@ +package be.cylab.java.roc; + +public class RocCoordinates { + + private double false_alarm; + private double true_detection; + + public RocCoordinates(double false_alarm, double true_detection) { + this.setFalseAlarm(false_alarm); + this.setTrueDetection(true_detection); + } + + public double getFalseAlarm() { + return false_alarm; + } + + public void setFalseAlarm(double false_alarm) { + if (false_alarm < 0 || false_alarm > 1) { + throw new IllegalArgumentException("" + + "False alarm must be between 0 and 1"); + } + this.false_alarm = false_alarm; + } + + public double getTrueDetection() { + return true_detection; + } + + public void setTrueDetection(double true_detection) { + if (true_detection < 0 || true_detection > 1) { + throw new IllegalArgumentException( + "True detection value must be between 0 and 1"); + } + this.true_detection = true_detection; + } +} diff --git a/src/main/java/be/cylab/java/roc/Utils.java b/src/main/java/be/cylab/java/roc/Utils.java new file mode 100644 index 0000000..06aa542 --- /dev/null +++ b/src/main/java/be/cylab/java/roc/Utils.java @@ -0,0 +1,26 @@ +package be.cylab.java.roc; + +import java.util.List; + +public class Utils { + + public static int countPositiveExamples(final List<Point> points) { + int positive_examples = 0; + for (Point point : points) { + if (point.isTrueAlert()) { + positive_examples++; + } + } + return positive_examples; + } + + public static int countNegativeExamples(final List<Point> points) { + int negative_examples = 0; + for (Point point : points) { + if (!point.isTrueAlert()) { + negative_examples++; + } + } + return negative_examples; + } +} -- GitLab