From c93d8e770e6320c2aaa9e94d13935be9e83342fb Mon Sep 17 00:00:00 2001 From: Alex <croix.alexandre@gmail.com> Date: Tue, 27 Aug 2019 14:20:29 +0200 Subject: [PATCH] Add method to compute statistical data and save results in files txt --- .../be/cylab/java/wowa/training/Example.java | 18 +- .../be/cylab/java/wowa/training/Trainer.java | 6 + .../be/cylab/java/wowa/training/Utils.java | 174 +++++++++++++++++- 3 files changed, 189 insertions(+), 9 deletions(-) diff --git a/src/main/java/be/cylab/java/wowa/training/Example.java b/src/main/java/be/cylab/java/wowa/training/Example.java index f601d3e..d1055c0 100644 --- a/src/main/java/be/cylab/java/wowa/training/Example.java +++ b/src/main/java/be/cylab/java/wowa/training/Example.java @@ -1,6 +1,8 @@ package be.cylab.java.wowa.training; +import java.util.HashMap; +import java.util.Map; import java.util.logging.Level; import java.util.logging.Logger; import java.util.logging.SimpleFormatter; @@ -72,15 +74,15 @@ public final class Example { + (end_time - start_time) / 1000 + " seconds"); System.out.println("Run Cross validation"); - //HashMap<AbstractSolution, Double> solutions = trainer.runKFold( - // data_file, - // expected_file, - // 10, - // 10); + HashMap<AbstractSolution, Double> solutions = trainer.runKFold( + data_file, + expected_file, + 10, + 10); - //for (Map.Entry val : solutions.entrySet()) { - // System.out.println(val); - //} + for (Map.Entry val : solutions.entrySet()) { + System.out.println(val); + } } diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java index 01ff230..d533978 100644 --- a/src/main/java/be/cylab/java/wowa/training/Trainer.java +++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java @@ -125,6 +125,12 @@ public class Trainer { testing.getData(), testing.getExpected()); + Utils.computeStatisticalInformation( + testing.getData(), + testing.getExpected(), + 100, + "Fold_" + (i + 1) + ".txt", + sol); map.put(sol, score); } return map; diff --git a/src/main/java/be/cylab/java/wowa/training/Utils.java b/src/main/java/be/cylab/java/wowa/training/Utils.java index 362b56e..2ebce94 100644 --- a/src/main/java/be/cylab/java/wowa/training/Utils.java +++ b/src/main/java/be/cylab/java/wowa/training/Utils.java @@ -2,9 +2,10 @@ package be.cylab.java.wowa.training; import com.owlike.genson.GenericType; import com.owlike.genson.Genson; +import info.debatty.java.aggregation.WOWA; import org.apache.commons.lang3.ArrayUtils; -import java.io.IOException; +import java.io.*; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Paths; @@ -242,4 +243,175 @@ final class Utils { } } + static void computeStatisticalInformation( + final List<List<Double>> data, + final List<Double> expected, + final int nbr_values, + final String filename, + final AbstractSolution solution) { + int wowa_true_positive_counter = 0; + int wowa_true_negative_counter = 0; + int average_true_positive_counter = 0; + int average_true_negative_counter = 0; + + int wowa_false_positive_counter = 0; + int wowa_false_negative_counter = 0; + int average_false_positive_counter = 0; + int average_false_negative_counter = 0; + + double iteration = 1 / (double) nbr_values; + + + for (double trigger_value = 0.00; + trigger_value < 1; trigger_value += iteration) { + for (int j = 0; j < data.size(); j++) { + WOWA wowa = new WOWA(Utils.convertListDoubleToArrayDouble( + solution.weights_w), + Utils.convertListDoubleToArrayDouble( + solution.weights_p)); + double wowa_result = wowa.aggregate( + Utils.convertListDoubleToArrayDouble(data.get(j))); + double average = Utils.sumListElements(data.get(j)) + / data.get(j).size(); + if (expected.get(j) == 1 && average > trigger_value && wowa_result > trigger_value) { + average_true_positive_counter++; + wowa_true_positive_counter++; + } else if (expected.get(j) == 1 && average < trigger_value && wowa_result > trigger_value) { + wowa_true_positive_counter++; + average_false_negative_counter++; + } else if (expected.get(j) == 1 && average > trigger_value && wowa_result < trigger_value) { + average_true_positive_counter++; + wowa_false_negative_counter++; + } else if (expected.get(j) == 1 && average < trigger_value && wowa_result < trigger_value) { + average_false_negative_counter++; + wowa_false_negative_counter++; + } else if (expected.get(j) == 0 && average < trigger_value && wowa_result < trigger_value) { + average_true_negative_counter++; + wowa_true_negative_counter++; + } else if (expected.get(j) == 0 && average < trigger_value && wowa_result > trigger_value) { + average_true_negative_counter++; + wowa_false_positive_counter++; + } else if (expected.get(j) == 0 && average > trigger_value && wowa_result < trigger_value) { + wowa_true_negative_counter++; + average_false_positive_counter++; + } else if (expected.get(j) == 0 && average > trigger_value && wowa_result > trigger_value) { + average_false_positive_counter++; + wowa_false_positive_counter++; + } + } + + //Positive conditions. + int wowa_positive_condition = wowa_true_positive_counter + wowa_false_negative_counter; + int average_positive_condition = average_true_positive_counter + average_false_negative_counter; + + //Negative conditions. + int wowa_negative_condition = wowa_true_negative_counter + wowa_false_positive_counter; + int average_negative_condition = average_true_negative_counter + average_false_positive_counter; + + //True positive rate. + double wowa_true_positive_rate = wowa_true_positive_counter / wowa_positive_condition; + double average_true_positive_rate = average_true_positive_counter / average_positive_condition; + + //True negative rate. + double wowa_true_negative_rate = wowa_true_negative_counter / wowa_negative_condition; + double average_true_negative_rate = average_true_negative_counter / average_negative_condition; + + //False negative rate. + double wowa_false_negative_rate = 1 - wowa_true_positive_rate; + double average_false_negative_rate = 1 - average_true_positive_rate; + + //Flase positive rate. + double wowa_false_positive_rate = 1 - wowa_true_negative_rate; + double average_false_positive_rate = 1 - average_true_negative_rate; + + //Precision (or positive prediction value). + double wowa_precision = wowa_true_positive_counter / (double) (wowa_true_positive_counter + wowa_false_positive_counter); + double average_precision = average_true_positive_counter / (double) (average_true_positive_counter + average_false_positive_counter); + + //Negative prediction value. + double wowa_negative_prediction_value = wowa_true_negative_counter / (double) (wowa_true_negative_counter + wowa_false_negative_counter); + double average_negative_prediction_value = average_true_negative_counter / (double) (average_true_negative_counter + average_false_negative_counter); + + //False discovery rate. + double wowa_false_discovery_rate = 1 - wowa_precision; + double average_false_discovery_rate = 1 - average_precision; + + //False omission rate. + double wowa_false_omission_rate = 1 - wowa_negative_prediction_value; + double average_false_omission_rate = 1 - average_negative_prediction_value; + + //Accuracy. + double wowa_accuracy = (wowa_true_positive_counter + wowa_true_negative_counter) / expected.size(); + double average_accuracy = (average_true_positive_counter + average_true_negative_counter) / expected.size(); + + //F1 score. + double wowa_f_1 = 2 * (wowa_precision * wowa_true_positive_rate) / (wowa_precision + wowa_true_positive_rate); + double average_f_1 = 2 * (average_precision * average_true_positive_rate) / (average_precision + average_true_positive_rate); + + try (FileWriter fw = new FileWriter(filename, true); + BufferedWriter bw = new BufferedWriter(fw); + PrintWriter writer = new PrintWriter(bw)) { + + writer.println("___________________________________________"); + writer.println("|TRIGGER VALUE :" + trigger_value); + writer.println("|_________________________________________|"); + writer.println("WOWA"); + + writer.println("Wowa true positive : " + wowa_true_positive_counter); + writer.println("Wowa true negative : " + wowa_true_negative_counter); + writer.println("Wowa false positive : " + wowa_false_positive_counter); + writer.println("Wowa false negative : " + wowa_false_negative_counter); + + writer.println("Wowa true positive rate : " + wowa_true_positive_rate); + writer.println("Wowa true negative rate : " + wowa_true_negative_rate); + writer.println("Wowa false positive rate : " + wowa_false_positive_rate); + writer.println("Wowa false negative rate : " + wowa_false_negative_rate); + + writer.println("Wowa precision : " + wowa_precision); + writer.println("Wowa negative predictive value : " + wowa_negative_prediction_value); + writer.println("Wowa false discovery rate : " + wowa_false_discovery_rate); + writer.println("Wowa false omission rate : " + wowa_false_omission_rate); + writer.println("Wowa accuracy : " + wowa_accuracy); + writer.println("Wowa F1 score : " + wowa_f_1); + + writer.println("AVERAGE\n"); + writer.println("Average true positive : " + average_true_positive_counter); + writer.println("Average true negative : " + average_true_negative_counter); + writer.println("Average false positive : " + average_false_positive_counter); + writer.println("Average false negative : " + average_false_negative_counter); + + writer.println("Average true positive rate : " + average_true_positive_rate); + writer.println("Average true negative rate : " + average_true_negative_rate); + writer.println("Average false positive rate : " + average_false_positive_rate); + writer.println("Average false negative rate : " + average_false_negative_rate); + + writer.println("Average precision : " + average_precision); + writer.println("Average negative predictive value : " + average_negative_prediction_value); + writer.println("Average false discovery rate : " + average_false_discovery_rate); + writer.println("Average false omission rate : " + average_false_omission_rate); + writer.println("Average accuracy : " + average_accuracy); + writer.println("Average F1 score : " + average_f_1); + + wowa_true_positive_counter = 0; + wowa_true_negative_counter = 0; + average_true_positive_counter = 0; + average_true_negative_counter = 0; + + wowa_false_positive_counter = 0; + wowa_false_negative_counter = 0; + average_false_positive_counter = 0; + average_false_negative_counter = 0; + + } catch (FileNotFoundException e) { + e.printStackTrace(); + } catch (UnsupportedEncodingException e) { + e.printStackTrace(); + } catch (IOException e) { + e.printStackTrace(); + } + + } + + } + } -- GitLab