From 4a2b5ef78f9ec7802a0aef2d7c9c504c87ea8499 Mon Sep 17 00:00:00 2001 From: "a.croix" <croix.alexandre@gmail.com> Date: Mon, 18 Feb 2019 16:21:23 +0100 Subject: [PATCH] First version. Need tests --- .../be/cylab/java/wowa/training/Example.java | 6 + .../java/wowa/training/SolutionDistance.java | 16 +- .../be/cylab/java/wowa/training/Trainer.java | 183 ++++++++++++++++-- .../be/cylab/java/wowa/training/Utils.java | 7 +- 4 files changed, 191 insertions(+), 21 deletions(-) diff --git a/src/main/java/be/cylab/java/wowa/training/Example.java b/src/main/java/be/cylab/java/wowa/training/Example.java index a7523b2..f2351b8 100644 --- a/src/main/java/be/cylab/java/wowa/training/Example.java +++ b/src/main/java/be/cylab/java/wowa/training/Example.java @@ -2,6 +2,7 @@ package be.cylab.java.wowa.training; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.logging.Logger; @@ -29,6 +30,11 @@ public class Example { for (SolutionDistance solution : population) { System.out.println(solution); } + System.out.println("Trié\n"); + Collections.sort(population); + for (SolutionDistance solution : population) { + System.out.println(solution); + } } diff --git a/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java b/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java index ac03857..3b375c3 100644 --- a/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java +++ b/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java @@ -3,9 +3,10 @@ package be.cylab.java.wowa.training; //import info.debatty.java.aggregation.WOWA; import java.util.Arrays; +import java.util.Collections; import java.util.List; -public class SolutionDistance { +public class SolutionDistance implements Comparable<SolutionDistance> { public double[] weights_w; public double[] weights_p; @@ -53,14 +54,17 @@ public class SolutionDistance { } + @Override + public int compareTo(SolutionDistance solution) { + return this.getDistance() > solution.getDistance() ? -1 : this.getDistance() < solution.getDistance() ? 1 : 0; + } + + public void normalize() { this.weights_w = Utils.normalizeWeights(this.weights_w); this.weights_p = Utils.normalizeWeights(this.weights_p); } - public static void sort(List<SolutionDistance> solutions) { - - } public double[] getWeights_w() { return weights_w; @@ -73,4 +77,8 @@ public class SolutionDistance { public double getDistance() { return distance; } + + public void setDistance(double distance) { + this.distance = distance; + } } diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java index cf287dd..2e7e2dd 100644 --- a/src/main/java/be/cylab/java/wowa/training/Trainer.java +++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java @@ -1,6 +1,7 @@ package be.cylab.java.wowa.training; import java.util.ArrayList; +import java.util.Collections; import java.util.List; public class Trainer { @@ -10,15 +11,43 @@ public class Trainer { public Trainer(TrainerParameters parameters) { this.parameters = parameters; } -/* + public SolutionDistance run(List<double[]> data, double[] expected) { + List<SolutionDistance> current_population = this.generateInitialPopulationAndComputeDistances(this.getParameters().getPopulationSize(), data, expected); + + Collections.sort(current_population); + SolutionDistance best_solution = this.findBestSolution(current_population); + for (int generation = 0; generation < this.parameters.getMaxGenerationNumber(); generation++) { + current_population = this.performReproduction(current_population, data, expected); + SolutionDistance best_solution_of_current_population = this.findBestSolution(current_population); + + if (this.getParameters().getLogger() != null) { + + } + + //We found a new best solution + if (best_solution_of_current_population.getDistance() < best_solution.getDistance()) { + best_solution = best_solution_of_current_population; + } + if (Math.abs(best_solution.getDistance()) < this.getParameters().getTriggerDistance()) { + break; + } + } + return best_solution; } - public SolutionDistance findBestSolution(SolutionDistance[] solutions) { + public SolutionDistance findBestSolution(List<SolutionDistance> solutions) { + SolutionDistance best_solution = new SolutionDistance(solutions.get(0).weights_p.length); + for (SolutionDistance solution : solutions) { + if (solution.getDistance() < best_solution.getDistance()) { + best_solution = solution; + } + } + return best_solution; } -*/ + public List<SolutionDistance> generateInitialPopulation(int numberOfWeights, int populationSize) { List<SolutionDistance> population = new ArrayList<SolutionDistance>(); for (int i = 0; i < populationSize; i++) { @@ -33,9 +62,6 @@ public class Trainer { return population; } - public double getRandomDouble() { - return 0; - } public List<SolutionDistance> computeDistances(List<SolutionDistance> solutions, List<double[]> data, double[] expected) { for (SolutionDistance solution : solutions) { @@ -44,31 +70,156 @@ public class Trainer { return solutions; } -/* - private List<SolutionDistance> rouletteWheelSelection(List<SolutionDistance> solutions, int count) { + private List<SolutionDistance> rouletteWheelSelection(List<SolutionDistance> solutions, List<SolutionDistance> selected_elements, int count) { + double max = Utils.findMaxDistance(solutions); + double min = Utils.findMinDistance(solutions); + double sum = Utils.sumTotalDistance(solutions); + + + if (solutions.size() < count) { + throw new IllegalArgumentException("Not enough elements in population to select "+ count + "parents"); + } + while (selected_elements.size() < count) { + double tos = Math.random(); + double normalized_distance = 0; + for (SolutionDistance solution : solutions) { + normalized_distance = normalized_distance + (max + min - solution.getDistance()) / sum; + if (normalized_distance > tos) { + selected_elements.add(solution); + solutions.remove(solution); + break; + } + } + } + return selected_elements; } - private List<SolutionDistance> tournamentSelection(List<SolutionDistance> solutions, int count){ + private List<SolutionDistance> tournamentSelection(List<SolutionDistance> solutions, List<SolutionDistance> selected_elements, int count){ + + while (selected_elements.size() < count) { + Collections.sort(solutions); + + //Select random element in the population + int solution1_index = Utils.randomInteger(0, solutions.size()); + SolutionDistance solution1 = solutions.get(solution1_index); + //Select random element in the population + int solution2_index = Utils.randomInteger(0, solutions.size()); + SolutionDistance solution2 = solutions.get(solution2_index); + + //Compare the two selected element and put the solution with the smallest distance in selected list + if (solution1.getDistance() < solution2.getDistance()) { + selected_elements.add(solution1); + solutions.remove(solution1); + } else { + selected_elements.add(solution2); + solutions.remove(solution2); + + } + } + return selected_elements; } public List<SolutionDistance> selectParents(List<SolutionDistance> solutions, int count, int selectionMethod) { + List<SolutionDistance> selected_parents = new ArrayList<>(); + //Select the two best current solutions + + //Sort population + Collections.sort(solutions); + //Put the two best elements in selected parents list + selected_parents.add(solutions.get(0)); + selected_parents.add(solutions.get(1)); + //Remove the two best elements from the list solutions + solutions.remove(0); + solutions.remove(1); + + if(selectionMethod == TrainerParameters.SELECTION_METHOD_RWS) { + + return this.rouletteWheelSelection(solutions, selected_parents, count - 2); + } + else if (selectionMethod == TrainerParameters.SELECTION_METHOD_TOS) { + return this.tournamentSelection(solutions, selected_parents, count - 2); + } + throw new IllegalArgumentException("Invalid selection method"); } public List<SolutionDistance> doReproduction(List<SolutionDistance> solutions) { + int nbr_weights = solutions.get(0).weights_p.length; + + //we add children to the current list of solutions + while (solutions.size() < this.parameters.getPopulationSize()) { + int dad = Utils.randomInteger(0, solutions.size() - 1); + int mom; + do { + mom = Utils.randomInteger(0, solutions.size() - 1); + } while (dad == mom); + int cut_position = Utils.randomInteger(0, nbr_weights - 1); + + //beta is used to compute the new value at the cut position + double beta = Math.random(); + solutions = this.reproduce(solutions.get(dad), solutions.get(mom), solutions, cut_position, beta); + } + while(solutions.size() > this.getParameters().getPopulationSize()) { + solutions.remove(solutions.get(solutions.size() - 1)); + } + return solutions; } - public List<SolutionDistance> reproduce(SolutionDistance dad, SolutionDistance mom, int cutPosition, float beta) { - //Different compare to PHP implementation !! PHP allows to return several elements ! + public List<SolutionDistance> reproduce(SolutionDistance dad, SolutionDistance mom, List<SolutionDistance> solutions, int cutPosition, double beta) { + double pnew1W = dad.weights_w[cutPosition] - beta * (dad.weights_w[cutPosition] - mom.weights_w[cutPosition]); + double pnew2W = mom.weights_w[cutPosition] + beta * (dad.weights_w[cutPosition] - mom.weights_w[cutPosition]); + + double pnew1P = dad.weights_p[cutPosition] - beta * (dad.weights_p[cutPosition] - mom.weights_p[cutPosition]); + double pnew2P = mom.weights_p[cutPosition] + beta * (dad.weights_p[cutPosition] - mom.weights_p[cutPosition]); + + SolutionDistance child1 = new SolutionDistance(dad.weights_p.length); + SolutionDistance child2 = new SolutionDistance(dad.weights_p.length); + + for (int i = 0; i < cutPosition; i++) { + child1.weights_w[i] = dad.weights_w[i]; + child1.weights_p[i] = dad.weights_p[i]; + + child2.weights_w[i] = mom.weights_w[i]; + child2.weights_p[i] = mom.weights_p[i]; + } + + child1.weights_w[cutPosition] = pnew1W; + child2.weights_w[cutPosition] = pnew2W; + child1.weights_p[cutPosition] = pnew1P; + child2.weights_p[cutPosition] = pnew2P; + + int nbre_weights = dad.weights_w.length; + for (int i = cutPosition + 1; i < nbre_weights; i++) { + child1.weights_w[i] = mom.weights_w[i]; + child1.weights_p[i] = mom.weights_p[i]; + + child2.weights_w[i] = dad.weights_w[i]; + child2.weights_p[i] = dad.weights_p[i]; + } + //Check and correct only if we used a QuasiRandom generation + child1.normalize(); + child2.normalize(); + solutions.add(child1); + solutions.add(child2); + + return solutions; } public List<SolutionDistance> randomlyMutateGenes(List<SolutionDistance> solutions) { + double probability = this.getParameters().getMutationRate() / 100; + for (int i = 0; i < solutions.size(); i++) { + if (i > 1) { + solutions.get(i).randomlyMutateWithProbability(probability); + } + + } + return solutions; } -*/ + public TrainerParameters getParameters() { return this.parameters; } @@ -78,11 +229,15 @@ public class Trainer { List<SolutionDistance> initial_population = this.generateInitialPopulation(number_of_weights, populationSize); return this.computeDistances(initial_population, data, expected); } -/* + public List<SolutionDistance> performReproduction(List<SolutionDistance> population, List<double[]> data, double[] expected) { + List<SolutionDistance> parents = this.selectParents(population, this.getParameters().getNumberParents(), this.getParameters().getSelectionMethod()); + List<SolutionDistance> new_generation = this.doReproduction(parents); + List<SolutionDistance> mutated = this.randomlyMutateGenes(new_generation); + return this.computeDistances(mutated, data, expected); } -*/ + private void checkAndCorrectNullWeightVector(SolutionDistance child) { } diff --git a/src/main/java/be/cylab/java/wowa/training/Utils.java b/src/main/java/be/cylab/java/wowa/training/Utils.java index 0f31e02..680f633 100644 --- a/src/main/java/be/cylab/java/wowa/training/Utils.java +++ b/src/main/java/be/cylab/java/wowa/training/Utils.java @@ -2,6 +2,7 @@ package be.cylab.java.wowa.training; import java.lang.reflect.Array; import java.util.Arrays; +import java.util.List; import java.util.Random; public class Utils { @@ -15,7 +16,7 @@ public class Utils { return weightsNormalized; } - public double findMaxDistance(SolutionDistance[] solutions) { + public static double findMaxDistance(List<SolutionDistance> solutions) { double max = Double.NEGATIVE_INFINITY; for (SolutionDistance solution : solutions) { if(solution.getDistance() > max) { @@ -25,7 +26,7 @@ public class Utils { return max; } - public double findMinDistance(SolutionDistance[] solutions) { + public static double findMinDistance(List<SolutionDistance> solutions) { double min = Double.POSITIVE_INFINITY; for (SolutionDistance solution : solutions) { if (solution.getDistance() < min){ @@ -35,7 +36,7 @@ public class Utils { return min; } - public double sumTotalDistance(SolutionDistance[] solutions) { + public static double sumTotalDistance(List<SolutionDistance> solutions) { double sum = 0; for (SolutionDistance solution : solutions) { sum += solution.getDistance(); -- GitLab