Skip to content
Snippets Groups Projects
Commit da2e9ff7 authored by a.croix's avatar a.croix
Browse files

Add dotted line for random classification on P-R and ROC curves

parent f3e963d1
No related branches found
No related tags found
1 merge request!1Add dotted line for random classification on P-R and ROC curves
Pipeline #2698 passed
This commit is part of merge request !1. Comments created here will be created in the context of that merge request.
...@@ -2,8 +2,11 @@ package be.cylab.java.roc; ...@@ -2,8 +2,11 @@ package be.cylab.java.roc;
import com.opencsv.CSVReader; import com.opencsv.CSVReader;
import org.knowm.xchart.BitmapEncoder; import org.knowm.xchart.BitmapEncoder;
import org.knowm.xchart.QuickChart;
import org.knowm.xchart.XYChart; import org.knowm.xchart.XYChart;
import org.knowm.xchart.XYChartBuilder;
import org.knowm.xchart.XYSeries;
import org.knowm.xchart.style.lines.SeriesLines;
import org.knowm.xchart.style.markers.SeriesMarkers;
import java.io.FileInputStream; import java.io.FileInputStream;
import java.io.FileNotFoundException; import java.io.FileNotFoundException;
...@@ -131,16 +134,32 @@ public class PRCurve { ...@@ -131,16 +134,32 @@ public class PRCurve {
List<CurveCoordinates> pr_coordinates = computePrcPoints(); List<CurveCoordinates> pr_coordinates = computePrcPoints();
double[] precision = new double[pr_coordinates.size()]; double[] precision = new double[pr_coordinates.size()];
double[] recall = new double[pr_coordinates.size()]; double[] recall = new double[pr_coordinates.size()];
//Create arrays to store Recall-Precision coordinates
for (int i = 0; i < pr_coordinates.size(); i++) { for (int i = 0; i < pr_coordinates.size(); i++) {
precision[i] = pr_coordinates.get(i).getYAxis(); precision[i] = pr_coordinates.get(i).getYAxis();
recall[i] = pr_coordinates.get(i).getXAxis(); recall[i] = pr_coordinates.get(i).getXAxis();
} }
XYChart chart = QuickChart.getChart("PR-Curve", //Chart creation
"Recall", XYChart chart = new XYChartBuilder().width(600).height(400)
"Precision", .title("PR-Curve").xAxisTitle("Recall").yAxisTitle("Precision")
null, .build();
recall,
precision); //P / (P + N) computation
int pos_ex = Utils.countPositiveExamples(points);
int neg_ex = Utils.countNegativeExamples(points);
double y_dotted_position = pos_ex / (double) (pos_ex + neg_ex);
//Set up graph to visualize random classification
double[] dotted_x = new double[] {0, 1};
double[] dotted_y = new double[] {y_dotted_position, y_dotted_position};
XYSeries dotted = chart.addSeries("Random Classification",
dotted_x, dotted_y);
dotted.setLineStyle(SeriesLines.DASH_DASH);
dotted.setMarker(SeriesMarkers.NONE);
//Set up graph to visualize P-R curve
XYSeries pr = chart.addSeries("P-R curve", recall, precision);
pr.setMarker(SeriesMarkers.NONE);
//Save graph as PNG
try { try {
BitmapEncoder.saveBitmapWithDPI(chart, BitmapEncoder.saveBitmapWithDPI(chart,
filename, filename,
......
...@@ -2,8 +2,11 @@ package be.cylab.java.roc; ...@@ -2,8 +2,11 @@ package be.cylab.java.roc;
import com.opencsv.CSVReader; import com.opencsv.CSVReader;
import org.knowm.xchart.BitmapEncoder; import org.knowm.xchart.BitmapEncoder;
import org.knowm.xchart.QuickChart;
import org.knowm.xchart.XYChart; import org.knowm.xchart.XYChart;
import org.knowm.xchart.XYChartBuilder;
import org.knowm.xchart.XYSeries;
import org.knowm.xchart.style.lines.SeriesLines;
import org.knowm.xchart.style.markers.SeriesMarkers;
import java.io.FileInputStream; import java.io.FileInputStream;
import java.io.FileNotFoundException; import java.io.FileNotFoundException;
...@@ -141,12 +144,22 @@ public class Roc { ...@@ -141,12 +144,22 @@ public class Roc {
true_detection[i] = roc_coordinates.get(i).getYAxis(); true_detection[i] = roc_coordinates.get(i).getYAxis();
false_alarm[i] = roc_coordinates.get(i).getXAxis(); false_alarm[i] = roc_coordinates.get(i).getXAxis();
} }
XYChart chart = QuickChart.getChart("Roc curve", //Chart creation
"False Alarm", XYChart chart = new XYChartBuilder().width(600).height(400)
"True Detection", .title("PR-Curve").xAxisTitle("Recall").yAxisTitle("Precision")
null, .build();
false_alarm,
true_detection); //Set up graph to visualize random classification.
double[] dotted_x = new double[]{0, 1};
double[] dotted_y = new double[]{0, 1};
XYSeries dotted = chart.addSeries("Random Classification",
dotted_x, dotted_y);
dotted.setLineStyle(SeriesLines.DASH_DASH);
dotted.setMarker(SeriesMarkers.NONE);
//Set up graph to visualize P-R curve
XYSeries pr = chart.addSeries("ROC curve", false_alarm, true_detection);
pr.setMarker(SeriesMarkers.NONE);
try { try {
BitmapEncoder.saveBitmapWithDPI(chart, BitmapEncoder.saveBitmapWithDPI(chart,
filename, filename,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment