diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 29d7620949008d80370d79c5576414d03ef442c0..756dcd7ddd93b0396c7b815d0bbe310df1d37d93 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -9,5 +9,5 @@ image: maven:3.5.3-jdk-8 mvn:package: script: - - mvn clean verify + - mvn clean test diff --git a/pom.xml b/pom.xml index f7c824e08bb8693e9d73c91dc4c6b753b4c1a2b2..e62f97a467117b92fbb081807668ecdbbc20ded9 100644 --- a/pom.xml +++ b/pom.xml @@ -53,7 +53,10 @@ </execution> </executions> </plugin> - + <plugin> + <artifactId>maven-surefire-plugin</artifactId> + <version>2.22.1</version> + </plugin> </plugins> </build> <properties> diff --git a/src/main/java/be/cylab/java/wowa/training/Example.java b/src/main/java/be/cylab/java/wowa/training/Example.java index 4a3dea8a83d383092b281fc3502e1f9ea5642cc2..47527107e45e6d0fdb10b9030179ddc57caf6e15 100644 --- a/src/main/java/be/cylab/java/wowa/training/Example.java +++ b/src/main/java/be/cylab/java/wowa/training/Example.java @@ -14,16 +14,16 @@ public class Example { public static void main(String[] args) { Logger LOGGER = Logger.getLogger(Trainer.class.getName()); LOGGER.setLevel(Level.INFO); - TrainerParameters parameters = new TrainerParameters(LOGGER, 10, 60, 10, TrainerParameters.SELECTION_METHOD_RWS, 100); + TrainerParameters parameters = new TrainerParameters(LOGGER, 200, 60, 10, TrainerParameters.SELECTION_METHOD_RWS, 100); Trainer trainer = new Trainer(parameters); List<double[]> data = new ArrayList<>(); - for (int i = 0; i < 5; i++) { + for (int i = 0; i < 1500; i++) { double[] element = {Math.random(), Math.random(), Math.random()}; data.add(element); } - double[] expected = new double[5]; + double[] expected = new double[1500]; - for (int j = 0; j < 5; j++) { + for (int j = 0; j < 1500; j++) { expected[j] = Math.random(); } diff --git a/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java b/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java index f2b9217202425cf5721f50d676a6320a621b26ff..54a3cf85a91e0d5dec3a9787f219513ccf8902f1 100644 --- a/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java +++ b/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java @@ -112,9 +112,9 @@ public class SolutionDistance implements Comparable<SolutionDistance>, Cloneable @Override public int compareTo(final SolutionDistance solution) { if (this.getDistance() > solution.getDistance()) { - return -1; - } else if (this.getDistance() < solution.getDistance()) { return 1; + } else if (this.getDistance() < solution.getDistance()) { + return -1; } else { return 0; } diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java index 4de5c75e204edbbb5466419409d5d4a3006648c1..13b0783fa45cf2ab6aa853bd15895650ed7d85aa 100644 --- a/src/main/java/be/cylab/java/wowa/training/Trainer.java +++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java @@ -51,7 +51,7 @@ public class Trainer { this.findBestSolution(current_population); if (this.getParameters().getLogger() != null) { - this.getParameters().getLogger().log(Level.WARNING, best_solution.toString()); + this.getParameters().getLogger().log(Level.INFO, "Generation " + (generation+1) + " : " + best_solution.toString()); } //We found a new best solution @@ -68,11 +68,10 @@ public class Trainer { * @param solutions * @return */ - private SolutionDistance findBestSolution( + SolutionDistance findBestSolution( final List<SolutionDistance> solutions ) { - SolutionDistance best_solution = - new SolutionDistance(solutions.get(0).getWeightsP().length); + SolutionDistance best_solution = solutions.get(0); for (SolutionDistance solution : solutions) { if (solution.getDistance() < best_solution.getDistance()) { @@ -87,7 +86,7 @@ public class Trainer { * @param population_size * @return */ - private List<SolutionDistance> generateInitialPopulation( + List<SolutionDistance> generateInitialPopulation( final int number_of_weights, final int population_size ) { @@ -105,7 +104,7 @@ public class Trainer { * @param expected * @return */ - private List<SolutionDistance> computeDistances( + List<SolutionDistance> computeDistances( final List<SolutionDistance> solutions, final List<double[]> data, final double[] expected @@ -123,7 +122,7 @@ public class Trainer { * @param count * @return */ - private List<SolutionDistance> rouletteWheelSelection( + List<SolutionDistance> rouletteWheelSelection( final List<SolutionDistance> solutions, final List<SolutionDistance> selected_elements, final int count @@ -161,7 +160,7 @@ public class Trainer { * @param count * @return */ - private List<SolutionDistance> tournamentSelection( + List<SolutionDistance> tournamentSelection( final List<SolutionDistance> solutions, final List<SolutionDistance> selected_elements, final int count @@ -171,11 +170,11 @@ public class Trainer { Collections.sort(solutions); //Select random element in the population - int solution1_index = Utils.randomInteger(0, solutions.size()); + int solution1_index = Utils.randomInteger(0, solutions.size() - 1); SolutionDistance solution1 = solutions.get(solution1_index); //Select random element in the population - int solution2_index = Utils.randomInteger(0, solutions.size()); + int solution2_index = Utils.randomInteger(0, solutions.size() - 1); SolutionDistance solution2 = solutions.get(solution2_index); //Compare the two selected element and put the solution with @@ -198,7 +197,7 @@ public class Trainer { * @param selection_method * @return */ - private List<SolutionDistance> selectParents( + List<SolutionDistance> selectParents( final List<SolutionDistance> solutions, final int count, final int selection_method @@ -219,12 +218,12 @@ public class Trainer { return this.rouletteWheelSelection(solutions, selected_parents, - count - 2 + count ); } else if (selection_method == TrainerParameters.SELECTION_METHOD_TOS) { return this.tournamentSelection(solutions, selected_parents, - count - 2 + count ); } throw new IllegalArgumentException("Invalid selection method"); @@ -235,7 +234,7 @@ public class Trainer { * @param solutions * @return */ - private List<SolutionDistance> doReproduction( + List<SolutionDistance> doReproduction( final List<SolutionDistance> solutions ) { int nbr_weights = solutions.get(0).getWeightsP().length; @@ -273,7 +272,7 @@ public class Trainer { * @param beta * @return */ - private void reproduce(final SolutionDistance dad, + void reproduce(final SolutionDistance dad, final SolutionDistance mom, final List<SolutionDistance> solutions, final int cut_position, @@ -331,7 +330,7 @@ public class Trainer { * @param solutions * @return */ - private List<SolutionDistance> randomlyMutateGenes( + List<SolutionDistance> randomlyMutateGenes( final List<SolutionDistance> solutions ) { double probability = this.getParameters().getMutationRate() / 100; @@ -348,7 +347,7 @@ public class Trainer { /** * @return */ - private TrainerParameters getParameters() { + TrainerParameters getParameters() { return this.parameters; } @@ -378,7 +377,7 @@ public class Trainer { * @param expected * @return */ - private List<SolutionDistance> performReproduction( + List<SolutionDistance> performReproduction( final List<SolutionDistance> population, final List<double[]> data, final double[] expected diff --git a/src/main/java/be/cylab/java/wowa/training/TrainerParameters.java b/src/main/java/be/cylab/java/wowa/training/TrainerParameters.java index 5e7c45ef8b110a3304cc997a0c21e60e1b46471b..8d5a196d3f195d32c2f1bb4835efff79ddeb8ae8 100644 --- a/src/main/java/be/cylab/java/wowa/training/TrainerParameters.java +++ b/src/main/java/be/cylab/java/wowa/training/TrainerParameters.java @@ -68,14 +68,14 @@ public class TrainerParameters { * Getter for number of parents. * @return int */ - final int getNumberOfParents() { + int getNumberOfParents() { return this.number_of_parents; } /** * @return Logger */ - final Logger getLogger() { + Logger getLogger() { return logger; } @@ -84,8 +84,8 @@ public class TrainerParameters { * Getter for population_size. * @return int */ - final int getPopulationSize() { - return population_size; + int getPopulationSize() { + return this.population_size; } /** @@ -110,7 +110,7 @@ public class TrainerParameters { private void setCrossoverRate(final int crossover_rate) { this.crossover_rate = crossover_rate; int nbr_parents = Math.round(this.population_size - * (1 - crossover_rate / 100)); + * (1 - (float)crossover_rate / 100)); if (nbr_parents % 2 == 1) { nbr_parents++; } diff --git a/src/test/java/be/cylab/java/wowa/training/SolutionDistanceTest.java b/src/test/java/be/cylab/java/wowa/training/SolutionDistanceTest.java index e6f474d097168e0fb7c2fa87b8dec3dd99389595..c81922fbc52ef4f7a1f393afab0a90371361b4a1 100644 --- a/src/test/java/be/cylab/java/wowa/training/SolutionDistanceTest.java +++ b/src/test/java/be/cylab/java/wowa/training/SolutionDistanceTest.java @@ -8,10 +8,10 @@ import java.util.Random; import static org.junit.jupiter.api.Assertions.*; -class SolutionDistanceTest { +class SolutionDistanceTest { - @Test - void generateSolutionDistance() { + @org.junit.jupiter.api.Test + public void testGenerateSolutionDistance() { SolutionDistance solution = new SolutionDistance(5, 1234); double actualW = Utils.sumArrayElements(solution.getWeightsW()); double actualP = Utils.sumArrayElements(solution.getWeightsP()); @@ -21,8 +21,8 @@ class SolutionDistanceTest { assertEquals(5, solution.getWeightsP().length); } - @Test - void computeScoreTo() { + @org.junit.jupiter.api.Test + public void computeScoreTo() { SolutionDistance solution = new SolutionDistance(5,1234); List<double[]> data = generateData(20, 5); double[] expected = generateExpected(20); diff --git a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java new file mode 100644 index 0000000000000000000000000000000000000000..db5421ce6e4cdca5f2375aa6284acac2f3d79bc5 --- /dev/null +++ b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java @@ -0,0 +1,134 @@ +package be.cylab.java.wowa.training; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.logging.Logger; + +import static org.junit.jupiter.api.Assertions.*; + +class TrainerTest { + + private Trainer trainer; + @BeforeEach + void setUp() { + + TrainerParameters parameters = new TrainerParameters(null, 100, 30, 10, TrainerParameters.SELECTION_METHOD_RWS, 100); + this.trainer = new Trainer(parameters); + } + + @Test + void testRun() { + } + + @Test + void testFindBestSolution() { + List<SolutionDistance> population = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + SolutionDistance solution = new SolutionDistance(5, i); + population.add(solution); + } + List<double[]> data = generateData(100, 5); + double[] expected = generateExpected(100); + List<SolutionDistance> computed_population = this.trainer.computeDistances(population, data, expected); + SolutionDistance bestSolution = this.trainer.findBestSolution(computed_population); + assertEquals(3.045715472368455, bestSolution.getDistance()); + } + + @Test + void testGenerateInitialPopulation() { + List<SolutionDistance> population = this.trainer.generateInitialPopulation(5, 100); + assertEquals(100, population.size()); + for (SolutionDistance solution : population) { + assertEquals(5, solution.getWeightsW().length); + assertEquals(5, solution.getWeightsP().length); + assertEquals(Double.POSITIVE_INFINITY, solution.getDistance()); + } + } + + @Test + void testComputeDistances() { + + } + + @Test + void testSelectParents() { + List<SolutionDistance> population = this.trainer.generateInitialPopulation(5, 100); + List<double[]> data = generateData(100, 5); + double[] expected = generateExpected(100); + List<SolutionDistance> computed_population = this.trainer.computeDistances(population, data, expected); + SolutionDistance best_solution = this.trainer.findBestSolution(computed_population); + List<SolutionDistance> parents = this.trainer.selectParents(computed_population, 30, TrainerParameters.SELECTION_METHOD_TOS); + assertEquals(best_solution, parents.get(0)); + assertEquals(30, parents.size()); + + computed_population = this.trainer.computeDistances(population, data, expected); + best_solution = this.trainer.findBestSolution(computed_population); + parents = this.trainer.selectParents(computed_population, 30, TrainerParameters.SELECTION_METHOD_RWS); + assertEquals(best_solution, parents.get(0)); + assertEquals(30, parents.size()); + + } + + @Test + void doReproduction() { + } + + @Test + void reproduce() { + List<SolutionDistance> population = new ArrayList<>(); + for (int i = 0; i < this.trainer.getParameters().getPopulationSize(); i++) { + SolutionDistance solution = new SolutionDistance(5, 2*i); + population.add(solution); + } + List<double[]> data = generateData(100, 5); + double[] expected = generateExpected(100); + population = this.trainer.computeDistances(population, data, expected); + List<SolutionDistance> parents = this.trainer.selectParents(population, 10, TrainerParameters.SELECTION_METHOD_RWS); + this.trainer.reproduce(parents.get(0), parents.get(1), parents, 1, 0.25875); + double[] expected_weight_w_child_1 = new double[] {0.1987717620825617, 0.154039554, 0.26261973461084054, 0.14551847852505223, 0.18295360349083772}; + double[] expected_weight_w_child_2 = new double[] {0.28709019832051996, 0.133065638, 0.2628906777282328, 0.15925182144723837, 0.2137985176214255}; + double[] expected_weight_p_child_1 = new double[] {0.3119444801528709, 0.130874978, 0.3091280121036283, 0.0943591649660173, 0.26790380888833576}; + double[] expected_weight_p_child_2 = new double[] {0.23379455171152877, 0.107402174, 0.0550712506540036, 0.22911863245792205, 0.2604029799761541}; + expected_weight_w_child_1 = Utils.normalizeWeights(expected_weight_w_child_1); + expected_weight_w_child_2 = Utils.normalizeWeights(expected_weight_w_child_2); + expected_weight_p_child_1 = Utils.normalizeWeights(expected_weight_p_child_1); + expected_weight_p_child_2 = Utils.normalizeWeights(expected_weight_p_child_2); + assertArrayEquals(parents.get(10).getWeightsW(), expected_weight_w_child_1, 0.0001); + assertArrayEquals(parents.get(10).getWeightsP(), expected_weight_p_child_1, 0.0001); + assertArrayEquals(parents.get(11).getWeightsW(), expected_weight_w_child_2, 0.0001); + assertArrayEquals(parents.get(11).getWeightsP(), expected_weight_p_child_2, 0.0001); + + } + + @Test + void randomlyMutateGenes() { + } + + + static List<double[]> generateData(final int size, final int weight_number) { + Random rnd = new Random(5489); + List<double[]> data = new ArrayList<>(); + for (int i = 0; i < size; i++) { + double[] vector = new double[weight_number]; + for (int j = 0; j < weight_number; j++) { + vector[j] = rnd.nextDouble(); + } + data.add(vector); + } + return data; + } + + static double[] generateExpected(final int size) { + Random rnd = new Random(5768); + double[] expected = new double[size]; + for (int i = 0; i < size; i++) { + expected[i] = rnd.nextDouble(); + } + return expected; + } +} \ No newline at end of file