diff --git a/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java b/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java index 876201fb651e7fa8deab49fefa54122600d38196..8cf396645468cb88754e5d89de044e4e943ce383 100644 --- a/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java +++ b/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java @@ -103,6 +103,7 @@ public abstract class AbstractSolution * @param data * @param expected * @return + * Have to be test in each child class and not in the absract class. */ final List<RocCoordinates> computeRocPoints( final List<List<Double>> data, diff --git a/src/test/java/be/cylab/java/wowa/training/AbstractSolutionTest.java b/src/test/java/be/cylab/java/wowa/training/AbstractSolutionTest.java index 4812e8e6a6ce02a26df28bae1958baa8d7b36a2d..27d1abb5bc1d4d446cfaab7f703961ebc8c5d402 100644 --- a/src/test/java/be/cylab/java/wowa/training/AbstractSolutionTest.java +++ b/src/test/java/be/cylab/java/wowa/training/AbstractSolutionTest.java @@ -1,7 +1,150 @@ package be.cylab.java.wowa.training; -import static org.junit.jupiter.api.Assertions.*; +import be.cylab.java.roc.RocCoordinates; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; class AbstractSolutionTest { + @Test + void randomlyMutateWithProbability() throws CloneNotSupportedException { + AbstractSolution solution = new SolutionDistance(5, 2548); + AbstractSolution copy = solution.clone(); + solution.randomlyMutateWithProbability(1); + double actualW = Utils.sumListElements(solution.getWeightsW()); + double actualP = Utils.sumListElements(solution.getWeightsP()); + assertEquals(1, actualW, 0.00001); + assertEquals(1, actualP, 0.00001); + assertFalse(copy.equals(solution)); + + } + @Test + /** + * The method is the same for all AbstractSolution children + * but it is not possible to instantiate an AbstractSolution object directly. + * The same test is performed with SolutionDistance and AUCSolution + */ + void testComputeRocPoints() { + //Test with SolutionDistance + SolutionDistance solution = new SolutionDistance(5, 1234); + List<List<Double>> data = generateData(20, 5); + List<Double> expected = generateExpectedForRoc(20); + List<RocCoordinates> coord = solution.computeRocPoints(data, expected,false); + List<RocCoordinates> roc_assert = generateRocCoordinatesAssertion(); + for (int i = 0; i < coord.size(); i++) { + assertEquals(roc_assert.get(i).getFalseAlarm(), coord.get(i).getFalseAlarm()); + assertEquals(roc_assert.get(i).getTrueDetection(), coord.get(i).getTrueDetection()); + //System.out.println("roc_assert.add(new RocCoordinates("+coord.get(i).getFalseAlarm() +"," + coord.get(i).getTrueDetection()+"));"); + } + //Test with SolutionAUC + SolutionAUC solution2 = new SolutionAUC(5, 1234); + List<List<Double>> data2 = generateData(20, 5); + List<Double> expected2 = generateExpectedForRoc(20); + List<RocCoordinates> coord2 = solution2.computeRocPoints(data2, expected2,false); + List<RocCoordinates> roc_assert2 = generateRocCoordinatesAssertion(); + for (int i = 0; i < coord2.size(); i++) { + assertEquals(roc_assert2.get(i).getFalseAlarm(), coord2.get(i).getFalseAlarm()); + assertEquals(roc_assert2.get(i).getTrueDetection(), coord2.get(i).getTrueDetection()); + //System.out.println("roc_assert.add(new RocCoordinates("+coord.get(i).getFalseAlarm() +"," + coord.get(i).getTrueDetection()+"));"); + } + } + + @Test + void testComputeAUC() { + SolutionDistance solution = new SolutionDistance(5, 1234); + List<List<Double>> data = generateData(20, 5); + List<Double> expected = generateExpectedForRoc(20); + double auc = solution.computeAUC(data, expected); + assertEquals(0.5119047619047619, auc); + + SolutionAUC solution2 = new SolutionAUC(5, 1234); + List<List<Double>> data2 = generateData(20, 5); + List<Double> expected2 = generateExpectedForRoc(20); + double auc2 = solution2.computeAUC(data2, expected2); + assertEquals(0.5119047619047619, auc2); + } + + + @Test void testComputeWOWAScoreWithData() { + SolutionDistance solution = new SolutionDistance(5, 1234); + List<List<Double>> data = generateData(10, 5); + double[] score = solution.computeWOWAScoreWithData(data); + double[] assertion = {0.38644819664932933, 0.4198252469435713, 0.5098847192488073, + 0.508683988618759, 0.7021905336693349, 0.5362799073045719, 0.5247915156715276, + 0.7171714189639167, 0.6161379941144322, 0.6181842437321532}; + assertEquals(assertion.length, score.length); + for (int i = 0; i < score.length; i++) { + assertEquals(assertion[i], score[i]); + } + + } + static List<List<Double>> generateData(final int size, final int weight_number) { + Random rnd = new Random(5489); + List<List<Double>> data = new ArrayList<>(); + for (int i = 0; i < size; i++) { + List<Double> vector = new ArrayList<>(); + for (int j = 0; j < weight_number; j++) { + vector.add(rnd.nextDouble()); + } + data.add(vector); + } + return data; + } + + + + static List<Double> generateExpected(final int size) { + Random rnd = new Random(5768); + List<Double> expected = new ArrayList<>(); + for (int i = 0; i < size; i++) { + expected.add(rnd.nextDouble()); + } + return expected; + } + + static List<Double> generateExpectedForRoc(final int size) { + Random rnd = new Random(5768); + List<Double> expected = new ArrayList<>(); + for (int i = 0; i < size; i++) { + if (rnd.nextDouble() <= 0.5) { + expected.add(new Double(0.0)); + } else { + expected.add(new Double(1.0)); + } + } + return expected; + } + + private static List<RocCoordinates> generateRocCoordinatesAssertion() { + List<RocCoordinates> roc_assert = new ArrayList<>(); + roc_assert.add(new RocCoordinates(0.0,0.0)); + roc_assert.add(new RocCoordinates(0.0,0.07142857142857142)); + roc_assert.add(new RocCoordinates(0.16666666666666666,0.07142857142857142)); + roc_assert.add(new RocCoordinates(0.16666666666666666,0.14285714285714285)); + roc_assert.add(new RocCoordinates(0.16666666666666666,0.21428571428571427)); + roc_assert.add(new RocCoordinates(0.3333333333333333,0.21428571428571427)); + roc_assert.add(new RocCoordinates(0.3333333333333333,0.2857142857142857)); + roc_assert.add(new RocCoordinates(0.3333333333333333,0.35714285714285715)); + roc_assert.add(new RocCoordinates(0.3333333333333333,0.42857142857142855)); + roc_assert.add(new RocCoordinates(0.3333333333333333,0.5)); + roc_assert.add(new RocCoordinates(0.3333333333333333,0.5714285714285714)); + roc_assert.add(new RocCoordinates(0.3333333333333333,0.6428571428571429)); + roc_assert.add(new RocCoordinates(0.5,0.6428571428571429)); + roc_assert.add(new RocCoordinates(0.6666666666666666,0.6428571428571429)); + roc_assert.add(new RocCoordinates(0.8333333333333334,0.6428571428571429)); + roc_assert.add(new RocCoordinates(0.8333333333333334,0.7142857142857143)); + roc_assert.add(new RocCoordinates(0.8333333333333334,0.7857142857142857)); + roc_assert.add(new RocCoordinates(0.8333333333333334,0.8571428571428571)); + roc_assert.add(new RocCoordinates(1.0,0.8571428571428571)); + roc_assert.add(new RocCoordinates(1.0,0.9285714285714286)); + roc_assert.add(new RocCoordinates(1.0,1.0)); + return roc_assert; + } + } \ No newline at end of file diff --git a/src/test/java/be/cylab/java/wowa/training/SolutionDistanceTest.java b/src/test/java/be/cylab/java/wowa/training/SolutionDistanceTest.java index 713052c478287690a3ce61248e01aea6cbeae71f..80860cd701f95f68b824163acb295642b5eb739b 100644 --- a/src/test/java/be/cylab/java/wowa/training/SolutionDistanceTest.java +++ b/src/test/java/be/cylab/java/wowa/training/SolutionDistanceTest.java @@ -8,11 +8,10 @@ import java.util.List; import java.util.Random; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; class SolutionDistanceTest { - @org.junit.jupiter.api.Test + @Test public void testGenerateSolutionDistance() { SolutionDistance solution = new SolutionDistance(5, 1234); double actualW = Utils.sumListElements(solution.getWeightsW()); @@ -23,7 +22,7 @@ class SolutionDistanceTest { assertEquals(5, solution.getWeightsP().size()); } - @org.junit.jupiter.api.Test + @Test public void computeScoreTo() { SolutionDistance solution = new SolutionDistance(5,1234); List<List<Double>> data = generateData(20, 5); @@ -32,18 +31,6 @@ class SolutionDistanceTest { assertEquals(1.5925246317727138, solution.getFitnessScore()); } - @Test - void randomlyMutateWithProbability() throws CloneNotSupportedException { - SolutionDistance solution = new SolutionDistance(5, 2548); - SolutionDistance copy = solution.clone(); - solution.randomlyMutateWithProbability(1); - double actualW = Utils.sumListElements(solution.getWeightsW()); - double actualP = Utils.sumListElements(solution.getWeightsP()); - assertEquals(1, actualW, 0.00001); - assertEquals(1, actualP, 0.00001); - assertFalse(copy.equals(solution)); - - } @Test void testComputeRocPoints() { SolutionDistance solution = new SolutionDistance(5, 1234); @@ -56,7 +43,6 @@ class SolutionDistanceTest { assertEquals(roc_assert.get(i).getTrueDetection(), coord.get(i).getTrueDetection()); //System.out.println("roc_assert.add(new RocCoordinates("+coord.get(i).getFalseAlarm() +"," + coord.get(i).getTrueDetection()+"));"); } - }