From 523ecf358b9a82ae08e1403b1c533e6844d64006 Mon Sep 17 00:00:00 2001 From: "a.croix" <croix.alexandre@gmail.com> Date: Tue, 19 Feb 2019 11:15:34 +0100 Subject: [PATCH] Fix some style --- .../be/cylab/java/wowa/training/Trainer.java | 158 +++++++++++++----- .../java/wowa/training/TrainerParameters.java | 10 +- 2 files changed, 124 insertions(+), 44 deletions(-) diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java index 961282e..2fae00f 100644 --- a/src/main/java/be/cylab/java/wowa/training/Trainer.java +++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java @@ -18,21 +18,35 @@ public class Trainer { * @param expected * @return */ - public SolutionDistance run(final List<double[]> data, final double[] expected) { - List<SolutionDistance> current_population = this.generateInitialPopulationAndComputeDistances(this.getParameters().getPopulationSize(), data, expected); + public SolutionDistance run(final List<double[]> data, + final double[] expected + ) { + List<SolutionDistance> current_population = + this.generateInitialPopulationAndComputeDistances( + this.getParameters().getPopulationSize(), + data, + expected + ); Collections.sort(current_population); - SolutionDistance best_solution = this.findBestSolution(current_population); - for (int generation = 0; generation < this.parameters.getMaxGenerationNumber(); generation++) { - current_population = this.performReproduction(current_population, data, expected); - SolutionDistance best_solution_of_current_population = this.findBestSolution(current_population); + SolutionDistance best_solution = + this.findBestSolution(current_population); + for (int generation = 0; + generation < this.parameters.getMaxGenerationNumber(); + generation++ + ) { + current_population = this.performReproduction( + current_population, data, expected); + SolutionDistance best_solution_of_current_population = + this.findBestSolution(current_population); if (this.getParameters().getLogger() != null) { } //We found a new best solution - if (best_solution_of_current_population.getDistance() < best_solution.getDistance()) { + if (best_solution_of_current_population.getDistance() < + best_solution.getDistance()) { best_solution = best_solution_of_current_population; } @@ -44,8 +58,11 @@ public class Trainer { * @param solutions * @return */ - private SolutionDistance findBestSolution(final List<SolutionDistance> solutions) { - SolutionDistance best_solution = new SolutionDistance(solutions.get(0).weights_p.length); + private SolutionDistance findBestSolution( + final List<SolutionDistance> solutions + ) { + SolutionDistance best_solution = + new SolutionDistance(solutions.get(0).weights_p.length); for (SolutionDistance solution : solutions) { if (solution.getDistance() < best_solution.getDistance()) { @@ -60,7 +77,10 @@ public class Trainer { * @param populationSize * @return */ - private List<SolutionDistance> generateInitialPopulation(int numberOfWeights, int populationSize) { + private List<SolutionDistance> generateInitialPopulation( + int numberOfWeights, + int populationSize + ) { List<SolutionDistance> population = new ArrayList<>(); for (int i = 0; i < populationSize; i++) { SolutionDistance solution = new SolutionDistance(numberOfWeights); @@ -80,7 +100,11 @@ public class Trainer { * @param expected * @return */ - private List<SolutionDistance> computeDistances(List<SolutionDistance> solutions, List<double[]> data, double[] expected) { + private List<SolutionDistance> computeDistances( + List<SolutionDistance> solutions, + List<double[]> data, + double[] expected + ) { for (SolutionDistance solution : solutions) { solution.computeScoreTo(data, expected); } @@ -94,20 +118,28 @@ public class Trainer { * @param count * @return */ - private List<SolutionDistance> rouletteWheelSelection(List<SolutionDistance> solutions, List<SolutionDistance> selected_elements, int count) { + private List<SolutionDistance> rouletteWheelSelection( + List<SolutionDistance> solutions, + List<SolutionDistance> selected_elements, + int count + ) { double max = Utils.findMaxDistance(solutions); double min = Utils.findMinDistance(solutions); double sum = Utils.sumTotalDistance(solutions); if (solutions.size() < count) { - throw new IllegalArgumentException("Not enough elements in population to select " + count + "parents"); + throw new IllegalArgumentException( + "Not enough elements in population to select " + + count + "parents" + ); } while (selected_elements.size() < count) { double tos = Math.random(); double normalized_distance = 0; for (SolutionDistance solution : solutions) { - normalized_distance = normalized_distance + (max + min - solution.getDistance()) / sum; + normalized_distance = normalized_distance + + (max + min - solution.getDistance()) / sum; if (normalized_distance > tos) { selected_elements.add(solution); solutions.remove(solution); @@ -124,7 +156,11 @@ public class Trainer { * @param count * @return */ - private List<SolutionDistance> tournamentSelection(List<SolutionDistance> solutions, List<SolutionDistance> selected_elements, int count) { + private List<SolutionDistance> tournamentSelection( + List<SolutionDistance> solutions, + List<SolutionDistance> selected_elements, + int count + ) { while (selected_elements.size() < count) { Collections.sort(solutions); @@ -137,7 +173,8 @@ public class Trainer { int solution2_index = Utils.randomInteger(0, solutions.size()); SolutionDistance solution2 = solutions.get(solution2_index); - //Compare the two selected element and put the solution with the smallest distance in selected list + //Compare the two selected element and put the solution with + // the smallest distance in selected list if (solution1.getDistance() < solution2.getDistance()) { selected_elements.add(solution1); solutions.remove(solution1); @@ -156,7 +193,11 @@ public class Trainer { * @param selectionMethod * @return */ - private List<SolutionDistance> selectParents(List<SolutionDistance> solutions, int count, int selectionMethod) { + private List<SolutionDistance> selectParents( + List<SolutionDistance> solutions, + int count, + int selectionMethod + ) { List<SolutionDistance> selected_parents = new ArrayList<>(); //Select the two best current solutions @@ -171,9 +212,15 @@ public class Trainer { if (selectionMethod == TrainerParameters.SELECTION_METHOD_RWS) { - return this.rouletteWheelSelection(solutions, selected_parents, count - 2); + return this.rouletteWheelSelection(solutions, + selected_parents, + count - 2 + ); } else if (selectionMethod == TrainerParameters.SELECTION_METHOD_TOS) { - return this.tournamentSelection(solutions, selected_parents, count - 2); + return this.tournamentSelection(solutions, + selected_parents, + count - 2 + ); } throw new IllegalArgumentException("Invalid selection method"); @@ -183,7 +230,9 @@ public class Trainer { * @param solutions * @return */ - private List<SolutionDistance> doReproduction(List<SolutionDistance> solutions) { + private List<SolutionDistance> doReproduction( + List<SolutionDistance> solutions + ) { int nbr_weights = solutions.get(0).weights_p.length; //we add children to the current list of solutions @@ -198,7 +247,12 @@ public class Trainer { //beta is used to compute the new value at the cut position double beta = Math.random(); - solutions = this.reproduce(solutions.get(dad), solutions.get(mom), solutions, cut_position, beta); + solutions = this.reproduce(solutions.get(dad), + solutions.get(mom), + solutions, + cut_position, + beta + ); } while (solutions.size() > this.getParameters().getPopulationSize()) { solutions.remove(solutions.get(solutions.size() - 1)); @@ -214,12 +268,21 @@ public class Trainer { * @param beta * @return */ - private List<SolutionDistance> reproduce(SolutionDistance dad, SolutionDistance mom, List<SolutionDistance> solutions, int cutPosition, double beta) { - double pnew1W = dad.weights_w[cutPosition] - beta * (dad.weights_w[cutPosition] - mom.weights_w[cutPosition]); - double pnew2W = mom.weights_w[cutPosition] + beta * (dad.weights_w[cutPosition] - mom.weights_w[cutPosition]); - - double pnew1P = dad.weights_p[cutPosition] - beta * (dad.weights_p[cutPosition] - mom.weights_p[cutPosition]); - double pnew2P = mom.weights_p[cutPosition] + beta * (dad.weights_p[cutPosition] - mom.weights_p[cutPosition]); + private List<SolutionDistance> reproduce(SolutionDistance dad, + SolutionDistance mom, + List<SolutionDistance> solutions, + int cutPosition, + double beta + ) { + double p_new1W = dad.weights_w[cutPosition] - beta + * (dad.weights_w[cutPosition] - mom.weights_w[cutPosition]); + double p_new2W = mom.weights_w[cutPosition] + beta + * (dad.weights_w[cutPosition] - mom.weights_w[cutPosition]); + + double p_new1P = dad.weights_p[cutPosition] - beta + * (dad.weights_p[cutPosition] - mom.weights_p[cutPosition]); + double p_new2P = mom.weights_p[cutPosition] + beta + * (dad.weights_p[cutPosition] - mom.weights_p[cutPosition]); SolutionDistance child1 = new SolutionDistance(dad.weights_p.length); SolutionDistance child2 = new SolutionDistance(dad.weights_p.length); @@ -232,10 +295,10 @@ public class Trainer { child2.weights_p[i] = mom.weights_p[i]; } - child1.weights_w[cutPosition] = pnew1W; - child2.weights_w[cutPosition] = pnew2W; - child1.weights_p[cutPosition] = pnew1P; - child2.weights_p[cutPosition] = pnew2P; + child1.weights_w[cutPosition] = p_new1W; + child2.weights_w[cutPosition] = p_new2W; + child1.weights_p[cutPosition] = p_new1P; + child2.weights_p[cutPosition] = p_new2P; int nbre_weights = dad.weights_w.length; for (int i = cutPosition + 1; i < nbre_weights; i++) { @@ -258,7 +321,9 @@ public class Trainer { * @param solutions * @return */ - private List<SolutionDistance> randomlyMutateGenes(List<SolutionDistance> solutions) { + private List<SolutionDistance> randomlyMutateGenes( + List<SolutionDistance> solutions + ) { double probability = this.getParameters().getMutationRate() / 100; for (int i = 0; i < solutions.size(); i++) { @@ -278,15 +343,23 @@ public class Trainer { } /** - * @param populationSize + * @param population_size * @param data * @param expected * @return */ - public List<SolutionDistance> generateInitialPopulationAndComputeDistances(int populationSize, List<double[]> data, double[] expected) { + public List<SolutionDistance> generateInitialPopulationAndComputeDistances( + int population_size, + List<double[]> data, + double[] expected + ) { int numberOfWeights = data.get(0).length; - List<SolutionDistance> initialPopulation = this.generateInitialPopulation(numberOfWeights, populationSize); - return this.computeDistances(initialPopulation, data, expected); + List<SolutionDistance> initial_population = + this.generateInitialPopulation( + numberOfWeights, + population_size + ); + return this.computeDistances(initial_population, data, expected); } /** @@ -295,10 +368,17 @@ public class Trainer { * @param expected * @return */ - private List<SolutionDistance> performReproduction(List<SolutionDistance> population, List<double[]> data, double[] expected) { - List<SolutionDistance> parents = this.selectParents(population, this.getParameters().getNumberParents(), this.getParameters().getSelectionMethod()); + private List<SolutionDistance> performReproduction( + List<SolutionDistance> population, + List<double[]> data, + double[] expected + ) { + List<SolutionDistance> parents = this.selectParents(population, + this.getParameters().getNumberParents(), + this.getParameters().getSelectionMethod()); List<SolutionDistance> new_generation = this.doReproduction(parents); - List<SolutionDistance> mutated = this.randomlyMutateGenes(new_generation); + List<SolutionDistance> mutated = + this.randomlyMutateGenes(new_generation); return this.computeDistances(mutated, data, expected); } diff --git a/src/main/java/be/cylab/java/wowa/training/TrainerParameters.java b/src/main/java/be/cylab/java/wowa/training/TrainerParameters.java index 3ea60c6..bceab73 100644 --- a/src/main/java/be/cylab/java/wowa/training/TrainerParameters.java +++ b/src/main/java/be/cylab/java/wowa/training/TrainerParameters.java @@ -67,8 +67,8 @@ public class TrainerParameters { * @return int, number of parents. Depends on the crossover rate */ final public int getNumberParents() { - int nbr_parents = Math.round(this.getPopulationSize() * - (1 - this.getCrossoverRate() / 100)); + int nbr_parents = Math.round(this.getPopulationSize() + * (1 - this.getCrossoverRate() / 100)); if (nbr_parents % 2 == 1) { nbr_parents++; } @@ -87,7 +87,7 @@ public class TrainerParameters { * Getter for population_size. * @return int */ - public int getPopulationSize() { + final public int getPopulationSize() { return population_size; } @@ -108,7 +108,7 @@ public class TrainerParameters { * Getter for crossover rate. * @return int */ - private int getCrossoverRate() { + final private int getCrossoverRate() { return crossover_rate; } @@ -124,7 +124,7 @@ public class TrainerParameters { * Getter for mutation rate. * @return int */ - public int getMutationRate() { + final public int getMutationRate() { return mutation_rate; } -- GitLab