From 2802984fc99387ef3494e1b77d647463cbc8e22a Mon Sep 17 00:00:00 2001
From: Alex <croix.alexandre@gmail.com>
Date: Wed, 21 Aug 2019 14:49:19 +0200
Subject: [PATCH] Add last unit tests + add hashCode and equals method in
 AbstractSolution class

---
 .../java/wowa/training/AbstractSolution.java  |  55 ++++-
 .../be/cylab/java/wowa/training/Trainer.java  |  47 ++++-
 .../cylab/java/wowa/training/TrainerTest.java | 193 +++++++++++++++---
 3 files changed, 256 insertions(+), 39 deletions(-)

diff --git a/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java b/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java
index 8cf3966..6ed5220 100644
--- a/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java
+++ b/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java
@@ -6,11 +6,7 @@ import info.debatty.java.aggregation.WOWA;
 
 import java.text.DateFormat;
 import java.text.SimpleDateFormat;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Date;
-import java.util.List;
-import java.util.Random;
+import java.util.*;
 
 
 /**
@@ -102,8 +98,7 @@ public abstract class AbstractSolution
     /**
      * @param data
      * @param expected
-     * @return
-     * Have to be test in each child class and not in the absract class.
+     * @return Have to be test in each child class and not in the absract class.
      */
     final List<RocCoordinates> computeRocPoints(
             final List<List<Double>> data,
@@ -129,6 +124,7 @@ public abstract class AbstractSolution
 
     /**
      * Method to compute AUC with List as argument.
+     *
      * @param data
      * @param expected
      * @return
@@ -152,6 +148,7 @@ public abstract class AbstractSolution
 
     /**
      * Method to compute AUC with filename as arugument.
+     *
      * @param filename_data
      * @param filename_expected
      * @return
@@ -202,6 +199,46 @@ public abstract class AbstractSolution
         }
     }
 
+    /**
+     * @param solution
+     * @return
+     */
+    @Override
+    public boolean equals(final Object solution) {
+        if (this == solution) {
+            return true;
+        }
+        if (solution == null) {
+            return false;
+        }
+        if (this.getClass() != solution.getClass()) {
+            return false;
+        }
+        AbstractSolution sol = (AbstractSolution) solution;
+        if (this.getWeightsW().size() != sol.getWeightsW().size()
+                || this.getWeightsP().size() != sol.getWeightsP().size()) {
+            return false;
+        }
+        for (int i = 0; i < sol.getWeightsP().size(); i++) {
+            if (!getWeightsW().get(i).equals(sol.getWeightsW().get(i))
+                    || !getWeightsP().get(i).equals(sol.getWeightsP().get(i))) {
+                return false;
+            }
+        }
+        if (this.getFitnessScore() != sol.getFitnessScore()) {
+            return false;
+        }
+        return true;
+    }
+
+    /**
+     * @return
+     */
+    @Override
+    public final int hashCode() {
+        return Objects.hash(weights_w, weights_p, fitness_score);
+    }
+
     /**
      * Function to normalize SolutionDistance weights.
      * Weights must be between 0 and 1 and the sum of the weight in a vector
@@ -255,8 +292,8 @@ public abstract class AbstractSolution
      * @return
      * @throws java.lang.CloneNotSupportedException if clone is not supported
      */
-    public final SolutionDistance clone() throws CloneNotSupportedException {
-        return (SolutionDistance) super.clone();
+    public final AbstractSolution clone() throws CloneNotSupportedException {
+        return (AbstractSolution) super.clone();
     }
 
 }
diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java
index b052ec5..4d4f956 100644
--- a/src/main/java/be/cylab/java/wowa/training/Trainer.java
+++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java
@@ -1,10 +1,7 @@
 package be.cylab.java.wowa.training;
 
 
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
+import java.util.*;
 import java.util.logging.Level;
 
 /**
@@ -367,6 +364,48 @@ public class Trainer {
         return selected_elements;
     }
 
+    /**
+     * Method used only for tests !!
+     * This method generates random number (tos) by using a seed !
+     * @param solutions
+     * @param selected_elements
+     * @param count
+     * @param seed
+     * @return
+     */
+    final List<AbstractSolution> rouletteWheelSelectionForTestOnly(
+            final List<AbstractSolution> solutions,
+            final List<AbstractSolution> selected_elements,
+            final int count,
+            final int seed) {
+
+        double max = Utils.findMaxDistance(solutions);
+        double min = Utils.findMinDistance(solutions);
+        double sum = Utils.sumTotalDistance(solutions);
+
+        if (solutions.size() < count) {
+            throw new IllegalArgumentException(
+                    "Not enough elements in population to select "
+                            + count + "parents"
+            );
+        }
+        Random rnd = new Random(seed);
+        while (selected_elements.size() < count) {
+            double tos = rnd.nextDouble();
+            double normalized_distance = 0;
+            for (AbstractSolution solution : solutions) {
+                normalized_distance = normalized_distance
+                        + (max + min - solution.getFitnessScore()) / sum;
+                if (normalized_distance > tos) {
+                    selected_elements.add(solution);
+                    solutions.remove(solution);
+                    break;
+                }
+            }
+        }
+        return selected_elements;
+    }
+
     /**
      * Select elements for the next generation.
      * Two elements are randomly selected. The element with the best
diff --git a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
index f444f8a..bce9edb 100644
--- a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
+++ b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
@@ -3,10 +3,7 @@ package be.cylab.java.wowa.training;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Random;
+import java.util.*;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -145,6 +142,59 @@ class TrainerTest {
 
     }
 
+    @Test
+    void testGenerateQuasiRandomInitialPopulationFewForcedValues() {
+        int population_size = 30;
+        int number_of_random_elements = 15;
+        int number_of_not_random_elements = 15;
+        int number_of_weights = 5;
+        int counter_random = 0;
+        int counter_non_random = 0;
+        boolean[][] memory = new boolean[number_of_weights][number_of_weights];
+        int memory_counter = 0;
+        List<AbstractSolution> population = this.trainer.generateQuasiRandomInitialPopulation(number_of_weights, population_size);
+        assertEquals(population_size, population.size());
+        for(AbstractSolution solution : population) {
+            assertEquals( 1.0, Utils.sumListElements(solution.getWeightsW()), 0.00001);
+            assertEquals( 1.0, Utils.sumListElements(solution.getWeightsP()), 0.00001);
+            assertEquals(5, solution.getWeightsW().size());
+            assertEquals(5, solution.getWeightsP().size());
+            assertEquals(Double.POSITIVE_INFINITY, solution.getFitnessScore());
+            if (solution.getWeightsW().get(0) != 0.0 && solution.getWeightsW().get(0) != 1.0) {
+                counter_non_random++;
+            }
+
+        }
+        for (int i = 0; i < number_of_weights; i++) {
+            for (int j = 0; j < number_of_weights; j++) {
+                List<Double> w = Utils.initializeListWithZeroValues(number_of_weights);
+                List<Double> p = Utils.initializeListWithZeroValues(number_of_weights);
+                w.set(i, 1.0);
+                p.set(j, 1.0);
+                AbstractSolution sol = new SolutionDistance(number_of_weights);
+                sol.setWeightsW(w);
+                sol.setWeightsP(p);
+                for (AbstractSolution solution : population) {
+                    if (solution.getWeightsW().equals(sol.getWeightsW()) && solution.getWeightsP().equals(sol.getWeightsP())) {
+                        counter_random++;
+                        memory[i][j] = !memory[i][j];
+                    }
+                }
+            }
+        }
+        assertEquals(number_of_random_elements, counter_random);
+        assertEquals(number_of_not_random_elements, counter_non_random);
+        for (boolean[] line : memory) {
+            for (boolean el : line) {
+                if (el) {
+                    memory_counter++;
+                }
+            }
+        }
+        assertEquals(number_of_random_elements, memory_counter);
+
+    }
+
     @Test
     void testComputeDistances() {
         List<AbstractSolution> population = new ArrayList<>();
@@ -183,6 +233,64 @@ class TrainerTest {
 
     }
 
+    /**
+     * The test function uses a modified version of rouletteWheelSelection
+     * because of the random feature inside the function.
+     * The test method uses a seed to force a specific "random" generation.
+     */
+    @Test
+    void testRouletteWheelSelection() {
+        int number_of_weights = 5;
+        int population_size = 50;
+        int number_of_selected_elements = 10;
+        List<List<Double>> data = generateData(500, number_of_weights);
+        List<Double> expected = generateExpected(500);
+        Random rnd = new Random(55646);
+        List<AbstractSolution> population = new ArrayList<>();
+        for (int i = 0; i < population_size; i++){
+            population.add(new SolutionDistance(number_of_weights, rnd.nextInt()));
+        }
+        trainer.computeDistances(population, data, expected);
+        Collections.sort(population);
+        List<AbstractSolution> selected = new ArrayList<>();
+        selected = trainer.rouletteWheelSelectionForTestOnly(population, selected, number_of_selected_elements, rnd.nextInt());
+        assertEquals(number_of_selected_elements, selected.size());
+        assertEquals(population_size - number_of_selected_elements, population.size());
+        double[] assertion = {7.622288005169746, 7.272594799246598, 7.147538725780932, 7.347211225721799, 7.554008720137534,
+                6.977912255816629, 7.553481886293156, 7.655488416261908, 7.255009528708453, 7.68296999390714};
+        for (int i = 0; i < selected.size(); i++) {
+            assertEquals(assertion[i], selected.get(i).getFitnessScore());
+        }
+    }
+
+    /**
+     * Test not perfect. It is not possible to test the selected elements because of the random feature of
+     * tournamentSelection.
+     */
+    @Test
+    void testTournamentSelection() {
+        int number_of_weights = 5;
+        int population_size = 50;
+        int number_of_selected_elements = 10;
+        List<List<Double>> data = generateData(500, number_of_weights);
+        List<Double> expected = generateExpected(500);
+        Random rnd = new Random(55646);
+        List<AbstractSolution> population = new ArrayList<>();
+        for (int i = 0; i < population_size; i++){
+            population.add(new SolutionDistance(number_of_weights, rnd.nextInt()));
+        }
+        trainer.computeDistances(population, data, expected);
+        Collections.sort(population);
+        List<AbstractSolution> selected = new ArrayList<>();
+        selected = trainer.tournamentSelection(population, selected, number_of_selected_elements);
+        assertEquals(number_of_selected_elements, selected.size());
+        assertEquals(population_size - number_of_selected_elements, population.size());
+    }
+
+    /**
+     * This method tests if the two best elements of the population are correctly selected for the next generation.
+     * The method SelectParents uses rouletteWheelSelection or tournamentSelection that are test before.
+     */
     @Test
     void testSelectParents() {
         List<AbstractSolution> population = this.trainer.generateInitialPopulation(5, 100);
@@ -202,12 +310,11 @@ class TrainerTest {
 
     }
 
+    /**
+     * This method test reproduce AND crossoverElements methods.
+     */
     @Test
-    void doReproduction() {
-    }
-
-    @Test
-    void reproduce() {
+    void testReproduce() {
         List<AbstractSolution> population = new ArrayList<>();
         for (int i = 0; i < this.trainer.getParameters().getPopulationSize(); i++) {
             AbstractSolution solution = new SolutionDistance(5, 2*i);
@@ -242,28 +349,62 @@ class TrainerTest {
 
     }
 
+    /**
+     * To perform this test we perform 1000 times a mutation (rate = 10) on a
+     * specific population (50 elements) and count for each the number of solution mutated.
+     * We perform a mean of these counts and check if the result is between 4 and 5.
+     */
     @Test
-    void randomlyMutateGenes() {
-        List<AbstractSolution> population = trainer.generateInitialPopulation(5, 100);
-        List<AbstractSolution> cloned_list = new ArrayList<>();
-        cloned_list.addAll(population);
-        List<Double> w = new ArrayList<>();
-        w.add(0.2);
-        w.add(0.2);
-        w.add(0.2);
-        w.add(0.2);
-        w.add(0.2);
-        population.get(0).setWeightsP(w);
-        int cnt = 0;
-        List<AbstractSolution> mutated_population = trainer.randomlyMutateGenes(population);
-        for (int i = 0; i < mutated_population.size(); i++) {
-            System.out.println(population.get(i));
-            System.out.println(cloned_list.get(i));
+    void testRandomlyMutateGenes() {
+        int general_cnt = 0;
+        for (int j = 0; j < 1000; j++) {
+            int number_of_weights = 5;
+            int population_size = 50;
+            int cnt = 0;
+            Random rnd = new Random(55646);
+            List<AbstractSolution> population = new ArrayList<>();
+            for (int i = 0; i < population_size; i++){
+                population.add(new SolutionDistance(number_of_weights, rnd.nextInt()));
+            }
+            List<AbstractSolution> population_copy = new ArrayList<>();
+            rnd = new Random(55646);
+            for (int i = 0; i < population_size; i++){
+                population_copy.add(new SolutionDistance(number_of_weights, rnd.nextInt()));
+            }
+            population = trainer.randomlyMutateGenes(population);
+            assertEquals(population_copy.size(), population.size());
+            for (int i = 0; i < population_size; i++) {
+                if (!population.get(i).equals(population_copy.get(i))) {
+                    cnt++;
+                }
+            }
+            general_cnt = general_cnt + cnt;
         }
-        System.out.println(cnt);
+        assertTrue(general_cnt / (double)1000 > 4.0);
+        assertTrue(general_cnt / (double)1000 < 5.0);
+
     }
 
+    /**
+     * This method doesn't require a specific test.
+     * Indeed, this method calls generateInitialPopulation or
+     * generateQuasiRandomInitialPopulation method followed by the method
+     * computeDistances.
+     * All these methods are tested individually before
+     */
+    @Test
+    void testGenerateInitialPopulationAndComputeDistances() {
+    }
 
+    /**
+     * This method doesn't require a specific test.
+     * Indeed, this method calls selectParents, then doReproduction and finally
+     * randomlyMutatedGenes.
+     * All these lethods are tested individually
+     */
+    @Test
+    void testPerformReproduction() {
+    }
     static List<List<Double>> generateData(final int size, final int weight_number) {
         Random rnd = new Random(5489);
         List<List<Double>> data = new ArrayList<>();
-- 
GitLab