From 2802984fc99387ef3494e1b77d647463cbc8e22a Mon Sep 17 00:00:00 2001 From: Alex <croix.alexandre@gmail.com> Date: Wed, 21 Aug 2019 14:49:19 +0200 Subject: [PATCH] Add last unit tests + add hashCode and equals method in AbstractSolution class --- .../java/wowa/training/AbstractSolution.java | 55 ++++- .../be/cylab/java/wowa/training/Trainer.java | 47 ++++- .../cylab/java/wowa/training/TrainerTest.java | 193 +++++++++++++++--- 3 files changed, 256 insertions(+), 39 deletions(-) diff --git a/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java b/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java index 8cf3966..6ed5220 100644 --- a/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java +++ b/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java @@ -6,11 +6,7 @@ import info.debatty.java.aggregation.WOWA; import java.text.DateFormat; import java.text.SimpleDateFormat; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Date; -import java.util.List; -import java.util.Random; +import java.util.*; /** @@ -102,8 +98,7 @@ public abstract class AbstractSolution /** * @param data * @param expected - * @return - * Have to be test in each child class and not in the absract class. + * @return Have to be test in each child class and not in the absract class. */ final List<RocCoordinates> computeRocPoints( final List<List<Double>> data, @@ -129,6 +124,7 @@ public abstract class AbstractSolution /** * Method to compute AUC with List as argument. + * * @param data * @param expected * @return @@ -152,6 +148,7 @@ public abstract class AbstractSolution /** * Method to compute AUC with filename as arugument. + * * @param filename_data * @param filename_expected * @return @@ -202,6 +199,46 @@ public abstract class AbstractSolution } } + /** + * @param solution + * @return + */ + @Override + public boolean equals(final Object solution) { + if (this == solution) { + return true; + } + if (solution == null) { + return false; + } + if (this.getClass() != solution.getClass()) { + return false; + } + AbstractSolution sol = (AbstractSolution) solution; + if (this.getWeightsW().size() != sol.getWeightsW().size() + || this.getWeightsP().size() != sol.getWeightsP().size()) { + return false; + } + for (int i = 0; i < sol.getWeightsP().size(); i++) { + if (!getWeightsW().get(i).equals(sol.getWeightsW().get(i)) + || !getWeightsP().get(i).equals(sol.getWeightsP().get(i))) { + return false; + } + } + if (this.getFitnessScore() != sol.getFitnessScore()) { + return false; + } + return true; + } + + /** + * @return + */ + @Override + public final int hashCode() { + return Objects.hash(weights_w, weights_p, fitness_score); + } + /** * Function to normalize SolutionDistance weights. * Weights must be between 0 and 1 and the sum of the weight in a vector @@ -255,8 +292,8 @@ public abstract class AbstractSolution * @return * @throws java.lang.CloneNotSupportedException if clone is not supported */ - public final SolutionDistance clone() throws CloneNotSupportedException { - return (SolutionDistance) super.clone(); + public final AbstractSolution clone() throws CloneNotSupportedException { + return (AbstractSolution) super.clone(); } } diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java index b052ec5..4d4f956 100644 --- a/src/main/java/be/cylab/java/wowa/training/Trainer.java +++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java @@ -1,10 +1,7 @@ package be.cylab.java.wowa.training; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; +import java.util.*; import java.util.logging.Level; /** @@ -367,6 +364,48 @@ public class Trainer { return selected_elements; } + /** + * Method used only for tests !! + * This method generates random number (tos) by using a seed ! + * @param solutions + * @param selected_elements + * @param count + * @param seed + * @return + */ + final List<AbstractSolution> rouletteWheelSelectionForTestOnly( + final List<AbstractSolution> solutions, + final List<AbstractSolution> selected_elements, + final int count, + final int seed) { + + double max = Utils.findMaxDistance(solutions); + double min = Utils.findMinDistance(solutions); + double sum = Utils.sumTotalDistance(solutions); + + if (solutions.size() < count) { + throw new IllegalArgumentException( + "Not enough elements in population to select " + + count + "parents" + ); + } + Random rnd = new Random(seed); + while (selected_elements.size() < count) { + double tos = rnd.nextDouble(); + double normalized_distance = 0; + for (AbstractSolution solution : solutions) { + normalized_distance = normalized_distance + + (max + min - solution.getFitnessScore()) / sum; + if (normalized_distance > tos) { + selected_elements.add(solution); + solutions.remove(solution); + break; + } + } + } + return selected_elements; + } + /** * Select elements for the next generation. * Two elements are randomly selected. The element with the best diff --git a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java index f444f8a..bce9edb 100644 --- a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java +++ b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java @@ -3,10 +3,7 @@ package be.cylab.java.wowa.training; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Random; +import java.util.*; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -145,6 +142,59 @@ class TrainerTest { } + @Test + void testGenerateQuasiRandomInitialPopulationFewForcedValues() { + int population_size = 30; + int number_of_random_elements = 15; + int number_of_not_random_elements = 15; + int number_of_weights = 5; + int counter_random = 0; + int counter_non_random = 0; + boolean[][] memory = new boolean[number_of_weights][number_of_weights]; + int memory_counter = 0; + List<AbstractSolution> population = this.trainer.generateQuasiRandomInitialPopulation(number_of_weights, population_size); + assertEquals(population_size, population.size()); + for(AbstractSolution solution : population) { + assertEquals( 1.0, Utils.sumListElements(solution.getWeightsW()), 0.00001); + assertEquals( 1.0, Utils.sumListElements(solution.getWeightsP()), 0.00001); + assertEquals(5, solution.getWeightsW().size()); + assertEquals(5, solution.getWeightsP().size()); + assertEquals(Double.POSITIVE_INFINITY, solution.getFitnessScore()); + if (solution.getWeightsW().get(0) != 0.0 && solution.getWeightsW().get(0) != 1.0) { + counter_non_random++; + } + + } + for (int i = 0; i < number_of_weights; i++) { + for (int j = 0; j < number_of_weights; j++) { + List<Double> w = Utils.initializeListWithZeroValues(number_of_weights); + List<Double> p = Utils.initializeListWithZeroValues(number_of_weights); + w.set(i, 1.0); + p.set(j, 1.0); + AbstractSolution sol = new SolutionDistance(number_of_weights); + sol.setWeightsW(w); + sol.setWeightsP(p); + for (AbstractSolution solution : population) { + if (solution.getWeightsW().equals(sol.getWeightsW()) && solution.getWeightsP().equals(sol.getWeightsP())) { + counter_random++; + memory[i][j] = !memory[i][j]; + } + } + } + } + assertEquals(number_of_random_elements, counter_random); + assertEquals(number_of_not_random_elements, counter_non_random); + for (boolean[] line : memory) { + for (boolean el : line) { + if (el) { + memory_counter++; + } + } + } + assertEquals(number_of_random_elements, memory_counter); + + } + @Test void testComputeDistances() { List<AbstractSolution> population = new ArrayList<>(); @@ -183,6 +233,64 @@ class TrainerTest { } + /** + * The test function uses a modified version of rouletteWheelSelection + * because of the random feature inside the function. + * The test method uses a seed to force a specific "random" generation. + */ + @Test + void testRouletteWheelSelection() { + int number_of_weights = 5; + int population_size = 50; + int number_of_selected_elements = 10; + List<List<Double>> data = generateData(500, number_of_weights); + List<Double> expected = generateExpected(500); + Random rnd = new Random(55646); + List<AbstractSolution> population = new ArrayList<>(); + for (int i = 0; i < population_size; i++){ + population.add(new SolutionDistance(number_of_weights, rnd.nextInt())); + } + trainer.computeDistances(population, data, expected); + Collections.sort(population); + List<AbstractSolution> selected = new ArrayList<>(); + selected = trainer.rouletteWheelSelectionForTestOnly(population, selected, number_of_selected_elements, rnd.nextInt()); + assertEquals(number_of_selected_elements, selected.size()); + assertEquals(population_size - number_of_selected_elements, population.size()); + double[] assertion = {7.622288005169746, 7.272594799246598, 7.147538725780932, 7.347211225721799, 7.554008720137534, + 6.977912255816629, 7.553481886293156, 7.655488416261908, 7.255009528708453, 7.68296999390714}; + for (int i = 0; i < selected.size(); i++) { + assertEquals(assertion[i], selected.get(i).getFitnessScore()); + } + } + + /** + * Test not perfect. It is not possible to test the selected elements because of the random feature of + * tournamentSelection. + */ + @Test + void testTournamentSelection() { + int number_of_weights = 5; + int population_size = 50; + int number_of_selected_elements = 10; + List<List<Double>> data = generateData(500, number_of_weights); + List<Double> expected = generateExpected(500); + Random rnd = new Random(55646); + List<AbstractSolution> population = new ArrayList<>(); + for (int i = 0; i < population_size; i++){ + population.add(new SolutionDistance(number_of_weights, rnd.nextInt())); + } + trainer.computeDistances(population, data, expected); + Collections.sort(population); + List<AbstractSolution> selected = new ArrayList<>(); + selected = trainer.tournamentSelection(population, selected, number_of_selected_elements); + assertEquals(number_of_selected_elements, selected.size()); + assertEquals(population_size - number_of_selected_elements, population.size()); + } + + /** + * This method tests if the two best elements of the population are correctly selected for the next generation. + * The method SelectParents uses rouletteWheelSelection or tournamentSelection that are test before. + */ @Test void testSelectParents() { List<AbstractSolution> population = this.trainer.generateInitialPopulation(5, 100); @@ -202,12 +310,11 @@ class TrainerTest { } + /** + * This method test reproduce AND crossoverElements methods. + */ @Test - void doReproduction() { - } - - @Test - void reproduce() { + void testReproduce() { List<AbstractSolution> population = new ArrayList<>(); for (int i = 0; i < this.trainer.getParameters().getPopulationSize(); i++) { AbstractSolution solution = new SolutionDistance(5, 2*i); @@ -242,28 +349,62 @@ class TrainerTest { } + /** + * To perform this test we perform 1000 times a mutation (rate = 10) on a + * specific population (50 elements) and count for each the number of solution mutated. + * We perform a mean of these counts and check if the result is between 4 and 5. + */ @Test - void randomlyMutateGenes() { - List<AbstractSolution> population = trainer.generateInitialPopulation(5, 100); - List<AbstractSolution> cloned_list = new ArrayList<>(); - cloned_list.addAll(population); - List<Double> w = new ArrayList<>(); - w.add(0.2); - w.add(0.2); - w.add(0.2); - w.add(0.2); - w.add(0.2); - population.get(0).setWeightsP(w); - int cnt = 0; - List<AbstractSolution> mutated_population = trainer.randomlyMutateGenes(population); - for (int i = 0; i < mutated_population.size(); i++) { - System.out.println(population.get(i)); - System.out.println(cloned_list.get(i)); + void testRandomlyMutateGenes() { + int general_cnt = 0; + for (int j = 0; j < 1000; j++) { + int number_of_weights = 5; + int population_size = 50; + int cnt = 0; + Random rnd = new Random(55646); + List<AbstractSolution> population = new ArrayList<>(); + for (int i = 0; i < population_size; i++){ + population.add(new SolutionDistance(number_of_weights, rnd.nextInt())); + } + List<AbstractSolution> population_copy = new ArrayList<>(); + rnd = new Random(55646); + for (int i = 0; i < population_size; i++){ + population_copy.add(new SolutionDistance(number_of_weights, rnd.nextInt())); + } + population = trainer.randomlyMutateGenes(population); + assertEquals(population_copy.size(), population.size()); + for (int i = 0; i < population_size; i++) { + if (!population.get(i).equals(population_copy.get(i))) { + cnt++; + } + } + general_cnt = general_cnt + cnt; } - System.out.println(cnt); + assertTrue(general_cnt / (double)1000 > 4.0); + assertTrue(general_cnt / (double)1000 < 5.0); + } + /** + * This method doesn't require a specific test. + * Indeed, this method calls generateInitialPopulation or + * generateQuasiRandomInitialPopulation method followed by the method + * computeDistances. + * All these methods are tested individually before + */ + @Test + void testGenerateInitialPopulationAndComputeDistances() { + } + /** + * This method doesn't require a specific test. + * Indeed, this method calls selectParents, then doReproduction and finally + * randomlyMutatedGenes. + * All these lethods are tested individually + */ + @Test + void testPerformReproduction() { + } static List<List<Double>> generateData(final int size, final int weight_number) { Random rnd = new Random(5489); List<List<Double>> data = new ArrayList<>(); -- GitLab