From 4e2160f8440b351d7b1723bc5f5d060f777d81ae Mon Sep 17 00:00:00 2001 From: "a.croix" <croix.alexandre@gmail.com> Date: Tue, 19 Feb 2019 09:29:47 +0100 Subject: [PATCH] Add basic Javadoc --- java-wowa-training.iml | 9 ++ .../java/wowa/training/SolutionDistance.java | 28 +++++- .../be/cylab/java/wowa/training/Trainer.java | 91 ++++++++++++++++++- .../be/cylab/java/wowa/training/Utils.java | 32 ++++++- 4 files changed, 155 insertions(+), 5 deletions(-) diff --git a/java-wowa-training.iml b/java-wowa-training.iml index db1671e..b26eb4a 100644 --- a/java-wowa-training.iml +++ b/java-wowa-training.iml @@ -22,5 +22,14 @@ <SOURCES /> </library> </orderEntry> + <orderEntry type="module-library"> + <library> + <CLASSES> + <root url="jar://$MAVEN_REPOSITORY$/org/jetbrains/annotations/16.0.2/annotations-16.0.2.jar!/" /> + </CLASSES> + <JAVADOC /> + <SOURCES /> + </library> + </orderEntry> </component> </module> \ No newline at end of file diff --git a/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java b/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java index dd10c6d..a6006d7 100644 --- a/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java +++ b/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java @@ -18,6 +18,10 @@ public class SolutionDistance implements Comparable<SolutionDistance> { this.weights_p = new double[weights_number]; } + /** + * + * @return + */ @Override public String toString() { return "SolutionDistance{" + @@ -27,13 +31,21 @@ public class SolutionDistance implements Comparable<SolutionDistance> { + '}'; } + /** + * + * @param data + * @param expected + */ public void computeScoreTo(List<double[]> data, double[] expected) { this.distance = Math.random(); } - + /** + * + * @param probability + */ public void randomlyMutateWithProbability(final double probability) { double tos = Math.random(); if (tos > probability) { @@ -55,18 +67,28 @@ public class SolutionDistance implements Comparable<SolutionDistance> { } + /** + * + * @param solution + * @return + */ @Override public int compareTo(final SolutionDistance solution) { return this.getDistance() > solution.getDistance() ? -1 : this.getDistance() < solution.getDistance() ? 1 : 0; } - + /** + * + */ public void normalize() { this.weights_w = Utils.normalizeWeights(this.weights_w); this.weights_p = Utils.normalizeWeights(this.weights_p); } - + /** + * + * @return + */ public double getDistance() { return distance; } diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java index bbcd0b5..a0324be 100644 --- a/src/main/java/be/cylab/java/wowa/training/Trainer.java +++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java @@ -1,5 +1,7 @@ package be.cylab.java.wowa.training; +import org.jetbrains.annotations.Contract; + import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -12,6 +14,12 @@ public class Trainer { this.parameters = parameters; } + /** + * + * @param data + * @param expected + * @return + */ public SolutionDistance run(final List<double[]> data, final double[] expected) { List<SolutionDistance> current_population = this.generateInitialPopulationAndComputeDistances(this.getParameters().getPopulationSize(), data, expected); @@ -37,6 +45,11 @@ public class Trainer { return best_solution; } + /** + * + * @param solutions + * @return + */ public SolutionDistance findBestSolution(final List<SolutionDistance> solutions) { SolutionDistance best_solution = new SolutionDistance(solutions.get(0).weights_p.length); @@ -48,6 +61,12 @@ public class Trainer { return best_solution; } + /** + * + * @param numberOfWeights + * @param populationSize + * @return + */ public List<SolutionDistance> generateInitialPopulation(int numberOfWeights, int populationSize) { List<SolutionDistance> population = new ArrayList<SolutionDistance>(); for (int i = 0; i < populationSize; i++) { @@ -62,7 +81,13 @@ public class Trainer { return population; } - + /** + * + * @param solutions + * @param data + * @param expected + * @return + */ public List<SolutionDistance> computeDistances(List<SolutionDistance> solutions, List<double[]> data, double[] expected) { for (SolutionDistance solution : solutions) { solution.computeScoreTo(data, expected); @@ -71,6 +96,14 @@ public class Trainer { } + /** + * + * @param solutions + * @param selected_elements + * @param count + * @return + */ + @Contract("_, _, _ -> param2") private List<SolutionDistance> rouletteWheelSelection(List<SolutionDistance> solutions, List<SolutionDistance> selected_elements, int count) { double max = Utils.findMaxDistance(solutions); double min = Utils.findMinDistance(solutions); @@ -95,6 +128,14 @@ public class Trainer { return selected_elements; } + /** + * + * @param solutions + * @param selected_elements + * @param count + * @return + */ + @Contract("_, _, _ -> param2") private List<SolutionDistance> tournamentSelection(List<SolutionDistance> solutions, List<SolutionDistance> selected_elements, int count) { while (selected_elements.size() < count) { @@ -121,6 +162,13 @@ public class Trainer { return selected_elements; } + /** + * + * @param solutions + * @param count + * @param selectionMethod + * @return + */ public List<SolutionDistance> selectParents(List<SolutionDistance> solutions, int count, int selectionMethod) { List<SolutionDistance> selected_parents = new ArrayList<>(); //Select the two best current solutions @@ -144,6 +192,11 @@ public class Trainer { } + /** + * + * @param solutions + * @return + */ public List<SolutionDistance> doReproduction(List<SolutionDistance> solutions) { int nbr_weights = solutions.get(0).weights_p.length; @@ -167,6 +220,15 @@ public class Trainer { return solutions; } + /** + * + * @param dad + * @param mom + * @param solutions + * @param cutPosition + * @param beta + * @return + */ public List<SolutionDistance> reproduce(SolutionDistance dad, SolutionDistance mom, List<SolutionDistance> solutions, int cutPosition, double beta) { double pnew1W = dad.weights_w[cutPosition] - beta * (dad.weights_w[cutPosition] - mom.weights_w[cutPosition]); double pnew2W = mom.weights_w[cutPosition] + beta * (dad.weights_w[cutPosition] - mom.weights_w[cutPosition]); @@ -207,6 +269,11 @@ public class Trainer { return solutions; } + /** + * + * @param solutions + * @return + */ public List<SolutionDistance> randomlyMutateGenes(List<SolutionDistance> solutions) { double probability = this.getParameters().getMutationRate() / 100; @@ -219,16 +286,34 @@ public class Trainer { return solutions; } + /** + * + * @return + */ public TrainerParameters getParameters() { return this.parameters; } + /** + * + * @param populationSize + * @param data + * @param expected + * @return + */ public List<SolutionDistance> generateInitialPopulationAndComputeDistances(int populationSize, List<double[]> data, double[] expected) { int number_of_weights = data.get(0).length; List<SolutionDistance> initial_population = this.generateInitialPopulation(number_of_weights, populationSize); return this.computeDistances(initial_population, data, expected); } + /** + * + * @param population + * @param data + * @param expected + * @return + */ public List<SolutionDistance> performReproduction(List<SolutionDistance> population, List<double[]> data, double[] expected) { List<SolutionDistance> parents = this.selectParents(population, this.getParameters().getNumberParents(), this.getParameters().getSelectionMethod()); List<SolutionDistance> new_generation = this.doReproduction(parents); @@ -237,6 +322,10 @@ public class Trainer { return this.computeDistances(mutated, data, expected); } + /** + * + * @param child + */ private void checkAndCorrectNullWeightVector(SolutionDistance child) { } diff --git a/src/main/java/be/cylab/java/wowa/training/Utils.java b/src/main/java/be/cylab/java/wowa/training/Utils.java index 27a5b87..3449769 100644 --- a/src/main/java/be/cylab/java/wowa/training/Utils.java +++ b/src/main/java/be/cylab/java/wowa/training/Utils.java @@ -6,7 +6,11 @@ import java.util.List; import java.util.Random; public class Utils { - + /** + * + * @param weights + * @return + */ public static double[] normalizeWeights(final double[] weights) { double sum_weight = Utils.sumArrayElements(weights); double[] weightsNormalized = new double[weights.length]; @@ -16,6 +20,11 @@ public class Utils { return weightsNormalized; } + /** + * + * @param solutions + * @return + */ public static double findMaxDistance(final List<SolutionDistance> solutions) { double max = Double.NEGATIVE_INFINITY; for (SolutionDistance solution : solutions) { @@ -26,6 +35,11 @@ public class Utils { return max; } + /** + * + * @param solutions + * @return + */ public static double findMinDistance(final List<SolutionDistance> solutions) { double min = Double.POSITIVE_INFINITY; for (SolutionDistance solution : solutions) { @@ -36,6 +50,11 @@ public class Utils { return min; } + /** + * + * @param solutions + * @return + */ public static double sumTotalDistance(final List<SolutionDistance> solutions) { double sum = 0; for (SolutionDistance solution : solutions) { @@ -44,6 +63,11 @@ public class Utils { return sum; } + /** + * + * @param array + * @return + */ public static double sumArrayElements(final double[] array) { float sum = 0; for (double weight : array) { @@ -52,6 +76,12 @@ public class Utils { return sum; } + /** + * + * @param min + * @param max + * @return + */ public static int randomInteger(final int min, final int max) { if (min >= max) { throw new IllegalArgumentException("Max must be greater then min"); -- GitLab