From dfec02bdf9e4399183ea67144093250d68822f63 Mon Sep 17 00:00:00 2001 From: Alex <croix.alexandre@gmail.com> Date: Thu, 8 Aug 2019 15:46:59 +0200 Subject: [PATCH] Improve tests + new tests --- .../java/wowa/training/AbstractSolution.java | 24 ++++- .../be/cylab/java/wowa/training/Example.java | 18 ++-- .../be/cylab/java/wowa/training/Utils.java | 27 +----- src/main/resources/data_test.json | 1 + src/main/resources/expected_test.json | 1 + .../cylab/java/wowa/training/UtilsTest.java | 87 ++++++++++++++++++- 6 files changed, 118 insertions(+), 40 deletions(-) create mode 100644 src/main/resources/data_test.json create mode 100644 src/main/resources/expected_test.json diff --git a/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java b/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java index 9184d27..01d254b 100644 --- a/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java +++ b/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java @@ -2,6 +2,7 @@ package be.cylab.java.wowa.training; import be.cylab.java.roc.Roc; import be.cylab.java.roc.RocCoordinates; +import info.debatty.java.aggregation.WOWA; import java.text.DateFormat; import java.text.SimpleDateFormat; @@ -113,7 +114,7 @@ public abstract class AbstractSolution } boolean[] true_alert = Utils.convertExpectedToBooleanArrayTrueAlert(expected); - double[] score = Utils.computeWOWAScoreWithData(this, data); + double[] score = this.computeWOWAScoreWithData(data); Roc roc = new Roc(score, true_alert); List<RocCoordinates> coordinates = roc.computeRocPoints(); if (save_on_csv) { @@ -143,7 +144,7 @@ public abstract class AbstractSolution = Utils.convertExpectedToBooleanArrayTrueAlert(expected); //For each elements in dataset, compute the wowa function. - double[] score = Utils.computeWOWAScoreWithData(this, data); + double[] score = this.computeWOWAScoreWithData(data); Roc roc = new Roc(score, true_alert); return roc.computeAUC(); } @@ -164,6 +165,25 @@ public abstract class AbstractSolution return computeAUC(data, expected); } + /** + * @param data + * @return + */ + double[] computeWOWAScoreWithData( + final List<List<Double>> data) { + double[] score = new double[data.size()]; + double[] weights_w + = Utils.convertListDoubleToArrayDouble(this.weights_w); + double[] weights_p + = Utils.convertListDoubleToArrayDouble(this.weights_p); + WOWA wowa = new WOWA(weights_w, weights_p); + for (int i = 0; i < data.size(); i++) { + score[i] = wowa.aggregate( + Utils.convertListDoubleToArrayDouble(data.get(i))); + } + return score; + } + /** * @param solution * @return int diff --git a/src/main/java/be/cylab/java/wowa/training/Example.java b/src/main/java/be/cylab/java/wowa/training/Example.java index d1055c0..f601d3e 100644 --- a/src/main/java/be/cylab/java/wowa/training/Example.java +++ b/src/main/java/be/cylab/java/wowa/training/Example.java @@ -1,8 +1,6 @@ package be.cylab.java.wowa.training; -import java.util.HashMap; -import java.util.Map; import java.util.logging.Level; import java.util.logging.Logger; import java.util.logging.SimpleFormatter; @@ -74,15 +72,15 @@ public final class Example { + (end_time - start_time) / 1000 + " seconds"); System.out.println("Run Cross validation"); - HashMap<AbstractSolution, Double> solutions = trainer.runKFold( - data_file, - expected_file, - 10, - 10); + //HashMap<AbstractSolution, Double> solutions = trainer.runKFold( + // data_file, + // expected_file, + // 10, + // 10); - for (Map.Entry val : solutions.entrySet()) { - System.out.println(val); - } + //for (Map.Entry val : solutions.entrySet()) { + // System.out.println(val); + //} } diff --git a/src/main/java/be/cylab/java/wowa/training/Utils.java b/src/main/java/be/cylab/java/wowa/training/Utils.java index 1bdfd5b..0d72094 100644 --- a/src/main/java/be/cylab/java/wowa/training/Utils.java +++ b/src/main/java/be/cylab/java/wowa/training/Utils.java @@ -2,7 +2,6 @@ package be.cylab.java.wowa.training; import com.owlike.genson.GenericType; import com.owlike.genson.Genson; -import info.debatty.java.aggregation.WOWA; import org.apache.commons.lang3.ArrayUtils; import java.io.IOException; @@ -149,12 +148,13 @@ final class Utils { String expected_json = Utils.readFileToString(filename); List<Double> expected = genson.deserialize(expected_json, new GenericType<List<Double>>() { - }); + }); return expected; } /** * Method to read file and convert into a String. + * * @param filename * @return */ @@ -170,28 +170,6 @@ final class Utils { } - - /** - * @param solution - * @param data - * @return - */ - static double[] computeWOWAScoreWithData( - final AbstractSolution solution, - final List<List<Double>> data) { - double[] score = new double[data.size()]; - double[] weights_w - = convertListDoubleToArrayDouble(solution.getWeightsW()); - double[] weights_p - = convertListDoubleToArrayDouble(solution.getWeightsP()); - WOWA wowa = new WOWA(weights_w, weights_p); - for (int i = 0; i < data.size(); i++) { - score[i] = - wowa.aggregate(convertListDoubleToArrayDouble(data.get(i))); - } - return score; - } - /** * @param expected * @return @@ -227,7 +205,6 @@ final class Utils { } /** - * * @param size * @return */ diff --git a/src/main/resources/data_test.json b/src/main/resources/data_test.json new file mode 100644 index 0000000..10af600 --- /dev/null +++ b/src/main/resources/data_test.json @@ -0,0 +1 @@ +[[0.0,0.0,0.6642622676660601,0.1311475409836066,0.019230769230769232],[0.0,0.0,0.6942392419583658,0.13888888888888884,0.0],[0.0,0.0,0.30981001256541474,0.15000000000000002,0.1923076923076923],[0.0,0.0,0.5102843604339973,0.15384615384615383,0.0],[0.0,0.0,0.489254930245176,0.15384615384615383,0.0],[0.0,0.0,0.6390525163016084,0.1428571428571429,0.0],[0.0,0.0,0.5187108597728124,0.13953488372093026,0.14743589743589744],[0.0,0.0,0.44489189824598674,0.15294117647058825,0.0],[0.0,0.0,0.8168535975077212,0.1717171717171717,0.0]] diff --git a/src/main/resources/expected_test.json b/src/main/resources/expected_test.json new file mode 100644 index 0000000..81ddee6 --- /dev/null +++ b/src/main/resources/expected_test.json @@ -0,0 +1 @@ +[0.6642622676660601,0.1311475409836066,0.019230769230769232,0.0,0.5102843604339973,0.15384615384615383,0.0,0.0,0.6390525163016084,0.1428571428571429] \ No newline at end of file diff --git a/src/test/java/be/cylab/java/wowa/training/UtilsTest.java b/src/test/java/be/cylab/java/wowa/training/UtilsTest.java index e389ff7..0f41032 100644 --- a/src/test/java/be/cylab/java/wowa/training/UtilsTest.java +++ b/src/test/java/be/cylab/java/wowa/training/UtilsTest.java @@ -5,6 +5,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Random; @@ -71,7 +72,47 @@ class UtilsTest { null_weights.add(0.0); assertThrows(IllegalArgumentException.class, () -> {Utils.normalizeWeights(null_weights);}); + } + + @Test + void testConvertJsonToDataForTrainer() { + List<List<Double>> data = Utils.convertJsonToDataForTrainer("./src/main/resources/data_test.json"); + List<List<Double>> assertion = new ArrayList<>(); + assertion.add(new ArrayList<>(Arrays.asList(0.0,0.0,0.6642622676660601,0.1311475409836066,0.019230769230769232))); + assertion.add(new ArrayList<>(Arrays.asList(0.0,0.0,0.6942392419583658,0.13888888888888884,0.0))); + assertion.add(new ArrayList<>(Arrays.asList(0.0,0.0,0.30981001256541474,0.15000000000000002,0.1923076923076923))); + assertion.add(new ArrayList<>(Arrays.asList(0.0,0.0,0.5102843604339973,0.15384615384615383,0.0))); + assertion.add(new ArrayList<>(Arrays.asList(0.0,0.0,0.489254930245176,0.15384615384615383,0.0))); + assertion.add(new ArrayList<>(Arrays.asList(0.0,0.0,0.6390525163016084,0.1428571428571429,0.0))); + assertion.add(new ArrayList<>(Arrays.asList(0.0,0.0,0.5187108597728124,0.13953488372093026,0.14743589743589744))); + assertion.add(new ArrayList<>(Arrays.asList(0.0,0.0,0.44489189824598674,0.15294117647058825,0.0))); + assertion.add(new ArrayList<>(Arrays.asList(0.0,0.0,0.8168535975077212,0.1717171717171717,0.0))); + assertEquals(assertion.size(), data.size()); + for (int i = 0; i < assertion.size(); i++) { + for (int j = 0; j < assertion.get(i).size(); j++) { + assertEquals(assertion.get(i).get(j), data.get(i).get(j)); + } + } + } + @Test + void testConvertJsonToExpectedForTrainer() { + List<Double> expected = Utils.convertJsonToExpectedForTrainer("./src/main/resources/expected_test.json"); + List<Double> assertion = new ArrayList<>(); + assertion.add(0.6642622676660601); + assertion.add(0.1311475409836066); + assertion.add(0.019230769230769232); + assertion.add(0.0); + assertion.add(0.5102843604339973); + assertion.add(0.15384615384615383); + assertion.add(0.0); + assertion.add(0.0); + assertion.add(0.6390525163016084); + assertion.add(0.1428571428571429); + assertEquals(assertion.size(), expected.size()); + for (int i = 0; i < assertion.size(); i++) { + assertEquals(assertion.get(i), expected.get(i)); + } } @Test @@ -128,19 +169,59 @@ class UtilsTest { } - @Test - void testComputeWOWAScoreWithData() { - } @Test void testConvertExpectedToBooleanArrayTrueAlert() { + List<Double> test = new ArrayList<>(); + test.add(0.0); + test.add(0.0); + test.add(1.0); + test.add(0.0); + test.add(1.0); + test.add(1.0); + test.add(0.0); + + boolean[] expected = {false, false, true, false, true, true, false}; + boolean[] test_converted = Utils.convertExpectedToBooleanArrayTrueAlert(test); + assertEquals(expected.length, test_converted.length); + for (int i = 0; i < expected.length; i++) { + assertEquals(expected[i], test_converted[i]); + } + + List<Double> testException = new ArrayList<>(); + testException.add(0.0); + testException.add(0.1); + testException.add(1.0); + testException.add(0.0); + testException.add(1.0); + testException.add(1.0); + testException.add(0.0); + assertThrows(IllegalStateException.class, () -> {Utils.convertExpectedToBooleanArrayTrueAlert(testException);}); } @Test void testConvertListDoubleToArrayDouble() { + List<Double> weights = new ArrayList<>(); + weights.add(0.365); + weights.add(0.658); + weights.add(0.14575); + weights.add(0.84547); + weights.add(0.65442485); + + double[] weights_array = Utils.convertListDoubleToArrayDouble(weights); + assertEquals(weights.size(), weights_array.length); + for (int i = 0; i < weights.size(); i++) { + assertEquals((double)weights.get(i), weights_array[i]); + } } @Test void testInitializeListWithZeroValues() { + int size = 12; + List<Double> zero_values = Utils.initializeListWithZeroValues(size); + assertEquals(size, zero_values.size()); + for (Double value : zero_values) { + assertEquals(0.0, (double)value); + } } } \ No newline at end of file -- GitLab