Skip to content
Snippets Groups Projects
Commit da2e9ff7 authored by a.croix's avatar a.croix
Browse files

Add dotted line for random classification on P-R and ROC curves

parent f3e963d1
No related branches found
No related tags found
1 merge request!1Add dotted line for random classification on P-R and ROC curves
Pipeline #2698 passed
This commit is part of merge request !1. Comments created here will be created in the context of that merge request.
......@@ -2,8 +2,11 @@ package be.cylab.java.roc;
import com.opencsv.CSVReader;
import org.knowm.xchart.BitmapEncoder;
import org.knowm.xchart.QuickChart;
import org.knowm.xchart.XYChart;
import org.knowm.xchart.XYChartBuilder;
import org.knowm.xchart.XYSeries;
import org.knowm.xchart.style.lines.SeriesLines;
import org.knowm.xchart.style.markers.SeriesMarkers;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
......@@ -131,16 +134,32 @@ public class PRCurve {
List<CurveCoordinates> pr_coordinates = computePrcPoints();
double[] precision = new double[pr_coordinates.size()];
double[] recall = new double[pr_coordinates.size()];
//Create arrays to store Recall-Precision coordinates
for (int i = 0; i < pr_coordinates.size(); i++) {
precision[i] = pr_coordinates.get(i).getYAxis();
recall[i] = pr_coordinates.get(i).getXAxis();
}
XYChart chart = QuickChart.getChart("PR-Curve",
"Recall",
"Precision",
null,
recall,
precision);
//Chart creation
XYChart chart = new XYChartBuilder().width(600).height(400)
.title("PR-Curve").xAxisTitle("Recall").yAxisTitle("Precision")
.build();
//P / (P + N) computation
int pos_ex = Utils.countPositiveExamples(points);
int neg_ex = Utils.countNegativeExamples(points);
double y_dotted_position = pos_ex / (double) (pos_ex + neg_ex);
//Set up graph to visualize random classification
double[] dotted_x = new double[] {0, 1};
double[] dotted_y = new double[] {y_dotted_position, y_dotted_position};
XYSeries dotted = chart.addSeries("Random Classification",
dotted_x, dotted_y);
dotted.setLineStyle(SeriesLines.DASH_DASH);
dotted.setMarker(SeriesMarkers.NONE);
//Set up graph to visualize P-R curve
XYSeries pr = chart.addSeries("P-R curve", recall, precision);
pr.setMarker(SeriesMarkers.NONE);
//Save graph as PNG
try {
BitmapEncoder.saveBitmapWithDPI(chart,
filename,
......
......@@ -2,8 +2,11 @@ package be.cylab.java.roc;
import com.opencsv.CSVReader;
import org.knowm.xchart.BitmapEncoder;
import org.knowm.xchart.QuickChart;
import org.knowm.xchart.XYChart;
import org.knowm.xchart.XYChartBuilder;
import org.knowm.xchart.XYSeries;
import org.knowm.xchart.style.lines.SeriesLines;
import org.knowm.xchart.style.markers.SeriesMarkers;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
......@@ -141,12 +144,22 @@ public class Roc {
true_detection[i] = roc_coordinates.get(i).getYAxis();
false_alarm[i] = roc_coordinates.get(i).getXAxis();
}
XYChart chart = QuickChart.getChart("Roc curve",
"False Alarm",
"True Detection",
null,
false_alarm,
true_detection);
//Chart creation
XYChart chart = new XYChartBuilder().width(600).height(400)
.title("PR-Curve").xAxisTitle("Recall").yAxisTitle("Precision")
.build();
//Set up graph to visualize random classification.
double[] dotted_x = new double[]{0, 1};
double[] dotted_y = new double[]{0, 1};
XYSeries dotted = chart.addSeries("Random Classification",
dotted_x, dotted_y);
dotted.setLineStyle(SeriesLines.DASH_DASH);
dotted.setMarker(SeriesMarkers.NONE);
//Set up graph to visualize P-R curve
XYSeries pr = chart.addSeries("ROC curve", false_alarm, true_detection);
pr.setMarker(SeriesMarkers.NONE);
try {
BitmapEncoder.saveBitmapWithDPI(chart,
filename,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment