Skip to content
Snippets Groups Projects
Commit 4e2160f8 authored by a.croix's avatar a.croix
Browse files

Add basic Javadoc

parent 8affc6af
No related branches found
No related tags found
No related merge requests found
Pipeline #1013 failed
...@@ -22,5 +22,14 @@ ...@@ -22,5 +22,14 @@
<SOURCES /> <SOURCES />
</library> </library>
</orderEntry> </orderEntry>
<orderEntry type="module-library">
<library>
<CLASSES>
<root url="jar://$MAVEN_REPOSITORY$/org/jetbrains/annotations/16.0.2/annotations-16.0.2.jar!/" />
</CLASSES>
<JAVADOC />
<SOURCES />
</library>
</orderEntry>
</component> </component>
</module> </module>
\ No newline at end of file
...@@ -18,6 +18,10 @@ public class SolutionDistance implements Comparable<SolutionDistance> { ...@@ -18,6 +18,10 @@ public class SolutionDistance implements Comparable<SolutionDistance> {
this.weights_p = new double[weights_number]; this.weights_p = new double[weights_number];
} }
/**
*
* @return
*/
@Override @Override
public String toString() { public String toString() {
return "SolutionDistance{" + return "SolutionDistance{" +
...@@ -27,13 +31,21 @@ public class SolutionDistance implements Comparable<SolutionDistance> { ...@@ -27,13 +31,21 @@ public class SolutionDistance implements Comparable<SolutionDistance> {
+ '}'; + '}';
} }
/**
*
* @param data
* @param expected
*/
public void computeScoreTo(List<double[]> data, double[] expected) { public void computeScoreTo(List<double[]> data, double[] expected) {
this.distance = Math.random(); this.distance = Math.random();
} }
/**
*
* @param probability
*/
public void randomlyMutateWithProbability(final double probability) { public void randomlyMutateWithProbability(final double probability) {
double tos = Math.random(); double tos = Math.random();
if (tos > probability) { if (tos > probability) {
...@@ -55,18 +67,28 @@ public class SolutionDistance implements Comparable<SolutionDistance> { ...@@ -55,18 +67,28 @@ public class SolutionDistance implements Comparable<SolutionDistance> {
} }
/**
*
* @param solution
* @return
*/
@Override @Override
public int compareTo(final SolutionDistance solution) { public int compareTo(final SolutionDistance solution) {
return this.getDistance() > solution.getDistance() ? -1 : this.getDistance() < solution.getDistance() ? 1 : 0; return this.getDistance() > solution.getDistance() ? -1 : this.getDistance() < solution.getDistance() ? 1 : 0;
} }
/**
*
*/
public void normalize() { public void normalize() {
this.weights_w = Utils.normalizeWeights(this.weights_w); this.weights_w = Utils.normalizeWeights(this.weights_w);
this.weights_p = Utils.normalizeWeights(this.weights_p); this.weights_p = Utils.normalizeWeights(this.weights_p);
} }
/**
*
* @return
*/
public double getDistance() { public double getDistance() {
return distance; return distance;
} }
......
package be.cylab.java.wowa.training; package be.cylab.java.wowa.training;
import org.jetbrains.annotations.Contract;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
...@@ -12,6 +14,12 @@ public class Trainer { ...@@ -12,6 +14,12 @@ public class Trainer {
this.parameters = parameters; this.parameters = parameters;
} }
/**
*
* @param data
* @param expected
* @return
*/
public SolutionDistance run(final List<double[]> data, final double[] expected) { public SolutionDistance run(final List<double[]> data, final double[] expected) {
List<SolutionDistance> current_population = this.generateInitialPopulationAndComputeDistances(this.getParameters().getPopulationSize(), data, expected); List<SolutionDistance> current_population = this.generateInitialPopulationAndComputeDistances(this.getParameters().getPopulationSize(), data, expected);
...@@ -37,6 +45,11 @@ public class Trainer { ...@@ -37,6 +45,11 @@ public class Trainer {
return best_solution; return best_solution;
} }
/**
*
* @param solutions
* @return
*/
public SolutionDistance findBestSolution(final List<SolutionDistance> solutions) { public SolutionDistance findBestSolution(final List<SolutionDistance> solutions) {
SolutionDistance best_solution = new SolutionDistance(solutions.get(0).weights_p.length); SolutionDistance best_solution = new SolutionDistance(solutions.get(0).weights_p.length);
...@@ -48,6 +61,12 @@ public class Trainer { ...@@ -48,6 +61,12 @@ public class Trainer {
return best_solution; return best_solution;
} }
/**
*
* @param numberOfWeights
* @param populationSize
* @return
*/
public List<SolutionDistance> generateInitialPopulation(int numberOfWeights, int populationSize) { public List<SolutionDistance> generateInitialPopulation(int numberOfWeights, int populationSize) {
List<SolutionDistance> population = new ArrayList<SolutionDistance>(); List<SolutionDistance> population = new ArrayList<SolutionDistance>();
for (int i = 0; i < populationSize; i++) { for (int i = 0; i < populationSize; i++) {
...@@ -62,7 +81,13 @@ public class Trainer { ...@@ -62,7 +81,13 @@ public class Trainer {
return population; return population;
} }
/**
*
* @param solutions
* @param data
* @param expected
* @return
*/
public List<SolutionDistance> computeDistances(List<SolutionDistance> solutions, List<double[]> data, double[] expected) { public List<SolutionDistance> computeDistances(List<SolutionDistance> solutions, List<double[]> data, double[] expected) {
for (SolutionDistance solution : solutions) { for (SolutionDistance solution : solutions) {
solution.computeScoreTo(data, expected); solution.computeScoreTo(data, expected);
...@@ -71,6 +96,14 @@ public class Trainer { ...@@ -71,6 +96,14 @@ public class Trainer {
} }
/**
*
* @param solutions
* @param selected_elements
* @param count
* @return
*/
@Contract("_, _, _ -> param2")
private List<SolutionDistance> rouletteWheelSelection(List<SolutionDistance> solutions, List<SolutionDistance> selected_elements, int count) { private List<SolutionDistance> rouletteWheelSelection(List<SolutionDistance> solutions, List<SolutionDistance> selected_elements, int count) {
double max = Utils.findMaxDistance(solutions); double max = Utils.findMaxDistance(solutions);
double min = Utils.findMinDistance(solutions); double min = Utils.findMinDistance(solutions);
...@@ -95,6 +128,14 @@ public class Trainer { ...@@ -95,6 +128,14 @@ public class Trainer {
return selected_elements; return selected_elements;
} }
/**
*
* @param solutions
* @param selected_elements
* @param count
* @return
*/
@Contract("_, _, _ -> param2")
private List<SolutionDistance> tournamentSelection(List<SolutionDistance> solutions, List<SolutionDistance> selected_elements, int count) { private List<SolutionDistance> tournamentSelection(List<SolutionDistance> solutions, List<SolutionDistance> selected_elements, int count) {
while (selected_elements.size() < count) { while (selected_elements.size() < count) {
...@@ -121,6 +162,13 @@ public class Trainer { ...@@ -121,6 +162,13 @@ public class Trainer {
return selected_elements; return selected_elements;
} }
/**
*
* @param solutions
* @param count
* @param selectionMethod
* @return
*/
public List<SolutionDistance> selectParents(List<SolutionDistance> solutions, int count, int selectionMethod) { public List<SolutionDistance> selectParents(List<SolutionDistance> solutions, int count, int selectionMethod) {
List<SolutionDistance> selected_parents = new ArrayList<>(); List<SolutionDistance> selected_parents = new ArrayList<>();
//Select the two best current solutions //Select the two best current solutions
...@@ -144,6 +192,11 @@ public class Trainer { ...@@ -144,6 +192,11 @@ public class Trainer {
} }
/**
*
* @param solutions
* @return
*/
public List<SolutionDistance> doReproduction(List<SolutionDistance> solutions) { public List<SolutionDistance> doReproduction(List<SolutionDistance> solutions) {
int nbr_weights = solutions.get(0).weights_p.length; int nbr_weights = solutions.get(0).weights_p.length;
...@@ -167,6 +220,15 @@ public class Trainer { ...@@ -167,6 +220,15 @@ public class Trainer {
return solutions; return solutions;
} }
/**
*
* @param dad
* @param mom
* @param solutions
* @param cutPosition
* @param beta
* @return
*/
public List<SolutionDistance> reproduce(SolutionDistance dad, SolutionDistance mom, List<SolutionDistance> solutions, int cutPosition, double beta) { public List<SolutionDistance> reproduce(SolutionDistance dad, SolutionDistance mom, List<SolutionDistance> solutions, int cutPosition, double beta) {
double pnew1W = dad.weights_w[cutPosition] - beta * (dad.weights_w[cutPosition] - mom.weights_w[cutPosition]); double pnew1W = dad.weights_w[cutPosition] - beta * (dad.weights_w[cutPosition] - mom.weights_w[cutPosition]);
double pnew2W = mom.weights_w[cutPosition] + beta * (dad.weights_w[cutPosition] - mom.weights_w[cutPosition]); double pnew2W = mom.weights_w[cutPosition] + beta * (dad.weights_w[cutPosition] - mom.weights_w[cutPosition]);
...@@ -207,6 +269,11 @@ public class Trainer { ...@@ -207,6 +269,11 @@ public class Trainer {
return solutions; return solutions;
} }
/**
*
* @param solutions
* @return
*/
public List<SolutionDistance> randomlyMutateGenes(List<SolutionDistance> solutions) { public List<SolutionDistance> randomlyMutateGenes(List<SolutionDistance> solutions) {
double probability = this.getParameters().getMutationRate() / 100; double probability = this.getParameters().getMutationRate() / 100;
...@@ -219,16 +286,34 @@ public class Trainer { ...@@ -219,16 +286,34 @@ public class Trainer {
return solutions; return solutions;
} }
/**
*
* @return
*/
public TrainerParameters getParameters() { public TrainerParameters getParameters() {
return this.parameters; return this.parameters;
} }
/**
*
* @param populationSize
* @param data
* @param expected
* @return
*/
public List<SolutionDistance> generateInitialPopulationAndComputeDistances(int populationSize, List<double[]> data, double[] expected) { public List<SolutionDistance> generateInitialPopulationAndComputeDistances(int populationSize, List<double[]> data, double[] expected) {
int number_of_weights = data.get(0).length; int number_of_weights = data.get(0).length;
List<SolutionDistance> initial_population = this.generateInitialPopulation(number_of_weights, populationSize); List<SolutionDistance> initial_population = this.generateInitialPopulation(number_of_weights, populationSize);
return this.computeDistances(initial_population, data, expected); return this.computeDistances(initial_population, data, expected);
} }
/**
*
* @param population
* @param data
* @param expected
* @return
*/
public List<SolutionDistance> performReproduction(List<SolutionDistance> population, List<double[]> data, double[] expected) { public List<SolutionDistance> performReproduction(List<SolutionDistance> population, List<double[]> data, double[] expected) {
List<SolutionDistance> parents = this.selectParents(population, this.getParameters().getNumberParents(), this.getParameters().getSelectionMethod()); List<SolutionDistance> parents = this.selectParents(population, this.getParameters().getNumberParents(), this.getParameters().getSelectionMethod());
List<SolutionDistance> new_generation = this.doReproduction(parents); List<SolutionDistance> new_generation = this.doReproduction(parents);
...@@ -237,6 +322,10 @@ public class Trainer { ...@@ -237,6 +322,10 @@ public class Trainer {
return this.computeDistances(mutated, data, expected); return this.computeDistances(mutated, data, expected);
} }
/**
*
* @param child
*/
private void checkAndCorrectNullWeightVector(SolutionDistance child) { private void checkAndCorrectNullWeightVector(SolutionDistance child) {
} }
......
...@@ -6,7 +6,11 @@ import java.util.List; ...@@ -6,7 +6,11 @@ import java.util.List;
import java.util.Random; import java.util.Random;
public class Utils { public class Utils {
/**
*
* @param weights
* @return
*/
public static double[] normalizeWeights(final double[] weights) { public static double[] normalizeWeights(final double[] weights) {
double sum_weight = Utils.sumArrayElements(weights); double sum_weight = Utils.sumArrayElements(weights);
double[] weightsNormalized = new double[weights.length]; double[] weightsNormalized = new double[weights.length];
...@@ -16,6 +20,11 @@ public class Utils { ...@@ -16,6 +20,11 @@ public class Utils {
return weightsNormalized; return weightsNormalized;
} }
/**
*
* @param solutions
* @return
*/
public static double findMaxDistance(final List<SolutionDistance> solutions) { public static double findMaxDistance(final List<SolutionDistance> solutions) {
double max = Double.NEGATIVE_INFINITY; double max = Double.NEGATIVE_INFINITY;
for (SolutionDistance solution : solutions) { for (SolutionDistance solution : solutions) {
...@@ -26,6 +35,11 @@ public class Utils { ...@@ -26,6 +35,11 @@ public class Utils {
return max; return max;
} }
/**
*
* @param solutions
* @return
*/
public static double findMinDistance(final List<SolutionDistance> solutions) { public static double findMinDistance(final List<SolutionDistance> solutions) {
double min = Double.POSITIVE_INFINITY; double min = Double.POSITIVE_INFINITY;
for (SolutionDistance solution : solutions) { for (SolutionDistance solution : solutions) {
...@@ -36,6 +50,11 @@ public class Utils { ...@@ -36,6 +50,11 @@ public class Utils {
return min; return min;
} }
/**
*
* @param solutions
* @return
*/
public static double sumTotalDistance(final List<SolutionDistance> solutions) { public static double sumTotalDistance(final List<SolutionDistance> solutions) {
double sum = 0; double sum = 0;
for (SolutionDistance solution : solutions) { for (SolutionDistance solution : solutions) {
...@@ -44,6 +63,11 @@ public class Utils { ...@@ -44,6 +63,11 @@ public class Utils {
return sum; return sum;
} }
/**
*
* @param array
* @return
*/
public static double sumArrayElements(final double[] array) { public static double sumArrayElements(final double[] array) {
float sum = 0; float sum = 0;
for (double weight : array) { for (double weight : array) {
...@@ -52,6 +76,12 @@ public class Utils { ...@@ -52,6 +76,12 @@ public class Utils {
return sum; return sum;
} }
/**
*
* @param min
* @param max
* @return
*/
public static int randomInteger(final int min, final int max) { public static int randomInteger(final int min, final int max) {
if (min >= max) { if (min >= max) {
throw new IllegalArgumentException("Max must be greater then min"); throw new IllegalArgumentException("Max must be greater then min");
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment