From da2e9ff79e172f265a4bc0de4177cf3f16565280 Mon Sep 17 00:00:00 2001 From: Alex <croix.alexandre@gmail.com> Date: Fri, 6 Dec 2019 15:04:02 +0100 Subject: [PATCH] Add dotted line for random classification on P-R and ROC curves --- src/main/java/be/cylab/java/roc/PRCurve.java | 33 +++++++++++++++----- src/main/java/be/cylab/java/roc/Roc.java | 27 +++++++++++----- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/src/main/java/be/cylab/java/roc/PRCurve.java b/src/main/java/be/cylab/java/roc/PRCurve.java index d5e2b6f..26abb27 100644 --- a/src/main/java/be/cylab/java/roc/PRCurve.java +++ b/src/main/java/be/cylab/java/roc/PRCurve.java @@ -2,8 +2,11 @@ package be.cylab.java.roc; import com.opencsv.CSVReader; import org.knowm.xchart.BitmapEncoder; -import org.knowm.xchart.QuickChart; import org.knowm.xchart.XYChart; +import org.knowm.xchart.XYChartBuilder; +import org.knowm.xchart.XYSeries; +import org.knowm.xchart.style.lines.SeriesLines; +import org.knowm.xchart.style.markers.SeriesMarkers; import java.io.FileInputStream; import java.io.FileNotFoundException; @@ -131,16 +134,32 @@ public class PRCurve { List<CurveCoordinates> pr_coordinates = computePrcPoints(); double[] precision = new double[pr_coordinates.size()]; double[] recall = new double[pr_coordinates.size()]; + //Create arrays to store Recall-Precision coordinates for (int i = 0; i < pr_coordinates.size(); i++) { precision[i] = pr_coordinates.get(i).getYAxis(); recall[i] = pr_coordinates.get(i).getXAxis(); } - XYChart chart = QuickChart.getChart("PR-Curve", - "Recall", - "Precision", - null, - recall, - precision); + //Chart creation + XYChart chart = new XYChartBuilder().width(600).height(400) + .title("PR-Curve").xAxisTitle("Recall").yAxisTitle("Precision") + .build(); + + //P / (P + N) computation + int pos_ex = Utils.countPositiveExamples(points); + int neg_ex = Utils.countNegativeExamples(points); + double y_dotted_position = pos_ex / (double) (pos_ex + neg_ex); + //Set up graph to visualize random classification + double[] dotted_x = new double[] {0, 1}; + double[] dotted_y = new double[] {y_dotted_position, y_dotted_position}; + XYSeries dotted = chart.addSeries("Random Classification", + dotted_x, dotted_y); + dotted.setLineStyle(SeriesLines.DASH_DASH); + dotted.setMarker(SeriesMarkers.NONE); + + //Set up graph to visualize P-R curve + XYSeries pr = chart.addSeries("P-R curve", recall, precision); + pr.setMarker(SeriesMarkers.NONE); + //Save graph as PNG try { BitmapEncoder.saveBitmapWithDPI(chart, filename, diff --git a/src/main/java/be/cylab/java/roc/Roc.java b/src/main/java/be/cylab/java/roc/Roc.java index 02faa51..a8a1de1 100644 --- a/src/main/java/be/cylab/java/roc/Roc.java +++ b/src/main/java/be/cylab/java/roc/Roc.java @@ -2,8 +2,11 @@ package be.cylab.java.roc; import com.opencsv.CSVReader; import org.knowm.xchart.BitmapEncoder; -import org.knowm.xchart.QuickChart; import org.knowm.xchart.XYChart; +import org.knowm.xchart.XYChartBuilder; +import org.knowm.xchart.XYSeries; +import org.knowm.xchart.style.lines.SeriesLines; +import org.knowm.xchart.style.markers.SeriesMarkers; import java.io.FileInputStream; import java.io.FileNotFoundException; @@ -141,12 +144,22 @@ public class Roc { true_detection[i] = roc_coordinates.get(i).getYAxis(); false_alarm[i] = roc_coordinates.get(i).getXAxis(); } - XYChart chart = QuickChart.getChart("Roc curve", - "False Alarm", - "True Detection", - null, - false_alarm, - true_detection); + //Chart creation + XYChart chart = new XYChartBuilder().width(600).height(400) + .title("PR-Curve").xAxisTitle("Recall").yAxisTitle("Precision") + .build(); + + //Set up graph to visualize random classification. + double[] dotted_x = new double[]{0, 1}; + double[] dotted_y = new double[]{0, 1}; + XYSeries dotted = chart.addSeries("Random Classification", + dotted_x, dotted_y); + dotted.setLineStyle(SeriesLines.DASH_DASH); + dotted.setMarker(SeriesMarkers.NONE); + + //Set up graph to visualize P-R curve + XYSeries pr = chart.addSeries("ROC curve", false_alarm, true_detection); + pr.setMarker(SeriesMarkers.NONE); try { BitmapEncoder.saveBitmapWithDPI(chart, filename, -- GitLab