From da2e9ff79e172f265a4bc0de4177cf3f16565280 Mon Sep 17 00:00:00 2001
From: Alex <croix.alexandre@gmail.com>
Date: Fri, 6 Dec 2019 15:04:02 +0100
Subject: [PATCH] Add dotted line for random classification on P-R and ROC
 curves

---
 src/main/java/be/cylab/java/roc/PRCurve.java | 33 +++++++++++++++-----
 src/main/java/be/cylab/java/roc/Roc.java     | 27 +++++++++++-----
 2 files changed, 46 insertions(+), 14 deletions(-)

diff --git a/src/main/java/be/cylab/java/roc/PRCurve.java b/src/main/java/be/cylab/java/roc/PRCurve.java
index d5e2b6f..26abb27 100644
--- a/src/main/java/be/cylab/java/roc/PRCurve.java
+++ b/src/main/java/be/cylab/java/roc/PRCurve.java
@@ -2,8 +2,11 @@ package be.cylab.java.roc;
 
 import com.opencsv.CSVReader;
 import org.knowm.xchart.BitmapEncoder;
-import org.knowm.xchart.QuickChart;
 import org.knowm.xchart.XYChart;
+import org.knowm.xchart.XYChartBuilder;
+import org.knowm.xchart.XYSeries;
+import org.knowm.xchart.style.lines.SeriesLines;
+import org.knowm.xchart.style.markers.SeriesMarkers;
 
 import java.io.FileInputStream;
 import java.io.FileNotFoundException;
@@ -131,16 +134,32 @@ public class PRCurve {
         List<CurveCoordinates> pr_coordinates = computePrcPoints();
         double[] precision = new double[pr_coordinates.size()];
         double[] recall = new double[pr_coordinates.size()];
+        //Create arrays to store Recall-Precision coordinates
         for (int i = 0; i < pr_coordinates.size(); i++) {
             precision[i] = pr_coordinates.get(i).getYAxis();
             recall[i] = pr_coordinates.get(i).getXAxis();
         }
-        XYChart chart = QuickChart.getChart("PR-Curve",
-                "Recall",
-                "Precision",
-                null,
-                recall,
-                precision);
+        //Chart creation
+        XYChart chart = new XYChartBuilder().width(600).height(400)
+                .title("PR-Curve").xAxisTitle("Recall").yAxisTitle("Precision")
+                .build();
+
+        //P / (P + N) computation
+        int pos_ex = Utils.countPositiveExamples(points);
+        int neg_ex = Utils.countNegativeExamples(points);
+        double y_dotted_position = pos_ex / (double) (pos_ex + neg_ex);
+        //Set up graph to visualize random classification
+        double[] dotted_x = new double[] {0, 1};
+        double[] dotted_y = new double[] {y_dotted_position, y_dotted_position};
+        XYSeries dotted = chart.addSeries("Random Classification",
+                dotted_x, dotted_y);
+        dotted.setLineStyle(SeriesLines.DASH_DASH);
+        dotted.setMarker(SeriesMarkers.NONE);
+
+        //Set up graph to visualize P-R curve
+        XYSeries pr = chart.addSeries("P-R curve", recall, precision);
+        pr.setMarker(SeriesMarkers.NONE);
+        //Save graph as PNG
         try {
             BitmapEncoder.saveBitmapWithDPI(chart,
                     filename,
diff --git a/src/main/java/be/cylab/java/roc/Roc.java b/src/main/java/be/cylab/java/roc/Roc.java
index 02faa51..a8a1de1 100644
--- a/src/main/java/be/cylab/java/roc/Roc.java
+++ b/src/main/java/be/cylab/java/roc/Roc.java
@@ -2,8 +2,11 @@ package be.cylab.java.roc;
 
 import com.opencsv.CSVReader;
 import org.knowm.xchart.BitmapEncoder;
-import org.knowm.xchart.QuickChart;
 import org.knowm.xchart.XYChart;
+import org.knowm.xchart.XYChartBuilder;
+import org.knowm.xchart.XYSeries;
+import org.knowm.xchart.style.lines.SeriesLines;
+import org.knowm.xchart.style.markers.SeriesMarkers;
 
 import java.io.FileInputStream;
 import java.io.FileNotFoundException;
@@ -141,12 +144,22 @@ public class Roc {
             true_detection[i] = roc_coordinates.get(i).getYAxis();
             false_alarm[i] = roc_coordinates.get(i).getXAxis();
         }
-        XYChart chart = QuickChart.getChart("Roc curve",
-                "False Alarm",
-                "True Detection",
-                null,
-                false_alarm,
-                true_detection);
+        //Chart creation
+        XYChart chart = new XYChartBuilder().width(600).height(400)
+                .title("PR-Curve").xAxisTitle("Recall").yAxisTitle("Precision")
+                .build();
+
+        //Set up graph to visualize random classification.
+        double[] dotted_x = new double[]{0, 1};
+        double[] dotted_y = new double[]{0, 1};
+        XYSeries dotted = chart.addSeries("Random Classification",
+                dotted_x, dotted_y);
+        dotted.setLineStyle(SeriesLines.DASH_DASH);
+        dotted.setMarker(SeriesMarkers.NONE);
+
+        //Set up graph to visualize P-R curve
+        XYSeries pr = chart.addSeries("ROC curve", false_alarm, true_detection);
+        pr.setMarker(SeriesMarkers.NONE);
         try {
             BitmapEncoder.saveBitmapWithDPI(chart,
                     filename,
-- 
GitLab