Skip to content
Snippets Groups Projects
Commit e93a54a3 authored by a.croix's avatar a.croix
Browse files

Unit tests for SolutionDistance + Trainer

parent af6677d4
No related branches found
No related tags found
No related merge requests found
Pipeline #1050 passed
......@@ -9,5 +9,5 @@ image: maven:3.5.3-jdk-8
mvn:package:
script:
- mvn clean verify
- mvn clean test
......@@ -53,7 +53,10 @@
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.22.1</version>
</plugin>
</plugins>
</build>
<properties>
......
......@@ -14,16 +14,16 @@ public class Example {
public static void main(String[] args) {
Logger LOGGER = Logger.getLogger(Trainer.class.getName());
LOGGER.setLevel(Level.INFO);
TrainerParameters parameters = new TrainerParameters(LOGGER, 10, 60, 10, TrainerParameters.SELECTION_METHOD_RWS, 100);
TrainerParameters parameters = new TrainerParameters(LOGGER, 200, 60, 10, TrainerParameters.SELECTION_METHOD_RWS, 100);
Trainer trainer = new Trainer(parameters);
List<double[]> data = new ArrayList<>();
for (int i = 0; i < 5; i++) {
for (int i = 0; i < 1500; i++) {
double[] element = {Math.random(), Math.random(), Math.random()};
data.add(element);
}
double[] expected = new double[5];
double[] expected = new double[1500];
for (int j = 0; j < 5; j++) {
for (int j = 0; j < 1500; j++) {
expected[j] = Math.random();
}
......
......@@ -112,9 +112,9 @@ public class SolutionDistance implements Comparable<SolutionDistance>, Cloneable
@Override
public int compareTo(final SolutionDistance solution) {
if (this.getDistance() > solution.getDistance()) {
return -1;
} else if (this.getDistance() < solution.getDistance()) {
return 1;
} else if (this.getDistance() < solution.getDistance()) {
return -1;
} else {
return 0;
}
......
......@@ -51,7 +51,7 @@ public class Trainer {
this.findBestSolution(current_population);
if (this.getParameters().getLogger() != null) {
this.getParameters().getLogger().log(Level.WARNING, best_solution.toString());
this.getParameters().getLogger().log(Level.INFO, "Generation " + (generation+1) + " : " + best_solution.toString());
}
//We found a new best solution
......@@ -68,11 +68,10 @@ public class Trainer {
* @param solutions
* @return
*/
private SolutionDistance findBestSolution(
SolutionDistance findBestSolution(
final List<SolutionDistance> solutions
) {
SolutionDistance best_solution =
new SolutionDistance(solutions.get(0).getWeightsP().length);
SolutionDistance best_solution = solutions.get(0);
for (SolutionDistance solution : solutions) {
if (solution.getDistance() < best_solution.getDistance()) {
......@@ -87,7 +86,7 @@ public class Trainer {
* @param population_size
* @return
*/
private List<SolutionDistance> generateInitialPopulation(
List<SolutionDistance> generateInitialPopulation(
final int number_of_weights,
final int population_size
) {
......@@ -105,7 +104,7 @@ public class Trainer {
* @param expected
* @return
*/
private List<SolutionDistance> computeDistances(
List<SolutionDistance> computeDistances(
final List<SolutionDistance> solutions,
final List<double[]> data,
final double[] expected
......@@ -123,7 +122,7 @@ public class Trainer {
* @param count
* @return
*/
private List<SolutionDistance> rouletteWheelSelection(
List<SolutionDistance> rouletteWheelSelection(
final List<SolutionDistance> solutions,
final List<SolutionDistance> selected_elements,
final int count
......@@ -161,7 +160,7 @@ public class Trainer {
* @param count
* @return
*/
private List<SolutionDistance> tournamentSelection(
List<SolutionDistance> tournamentSelection(
final List<SolutionDistance> solutions,
final List<SolutionDistance> selected_elements,
final int count
......@@ -171,11 +170,11 @@ public class Trainer {
Collections.sort(solutions);
//Select random element in the population
int solution1_index = Utils.randomInteger(0, solutions.size());
int solution1_index = Utils.randomInteger(0, solutions.size() - 1);
SolutionDistance solution1 = solutions.get(solution1_index);
//Select random element in the population
int solution2_index = Utils.randomInteger(0, solutions.size());
int solution2_index = Utils.randomInteger(0, solutions.size() - 1);
SolutionDistance solution2 = solutions.get(solution2_index);
//Compare the two selected element and put the solution with
......@@ -198,7 +197,7 @@ public class Trainer {
* @param selection_method
* @return
*/
private List<SolutionDistance> selectParents(
List<SolutionDistance> selectParents(
final List<SolutionDistance> solutions,
final int count,
final int selection_method
......@@ -219,12 +218,12 @@ public class Trainer {
return this.rouletteWheelSelection(solutions,
selected_parents,
count - 2
count
);
} else if (selection_method == TrainerParameters.SELECTION_METHOD_TOS) {
return this.tournamentSelection(solutions,
selected_parents,
count - 2
count
);
}
throw new IllegalArgumentException("Invalid selection method");
......@@ -235,7 +234,7 @@ public class Trainer {
* @param solutions
* @return
*/
private List<SolutionDistance> doReproduction(
List<SolutionDistance> doReproduction(
final List<SolutionDistance> solutions
) {
int nbr_weights = solutions.get(0).getWeightsP().length;
......@@ -273,7 +272,7 @@ public class Trainer {
* @param beta
* @return
*/
private void reproduce(final SolutionDistance dad,
void reproduce(final SolutionDistance dad,
final SolutionDistance mom,
final List<SolutionDistance> solutions,
final int cut_position,
......@@ -331,7 +330,7 @@ public class Trainer {
* @param solutions
* @return
*/
private List<SolutionDistance> randomlyMutateGenes(
List<SolutionDistance> randomlyMutateGenes(
final List<SolutionDistance> solutions
) {
double probability = this.getParameters().getMutationRate() / 100;
......@@ -348,7 +347,7 @@ public class Trainer {
/**
* @return
*/
private TrainerParameters getParameters() {
TrainerParameters getParameters() {
return this.parameters;
}
......@@ -378,7 +377,7 @@ public class Trainer {
* @param expected
* @return
*/
private List<SolutionDistance> performReproduction(
List<SolutionDistance> performReproduction(
final List<SolutionDistance> population,
final List<double[]> data,
final double[] expected
......
......@@ -68,14 +68,14 @@ public class TrainerParameters {
* Getter for number of parents.
* @return int
*/
final int getNumberOfParents() {
int getNumberOfParents() {
return this.number_of_parents;
}
/**
* @return Logger
*/
final Logger getLogger() {
Logger getLogger() {
return logger;
}
......@@ -84,8 +84,8 @@ public class TrainerParameters {
* Getter for population_size.
* @return int
*/
final int getPopulationSize() {
return population_size;
int getPopulationSize() {
return this.population_size;
}
/**
......@@ -110,7 +110,7 @@ public class TrainerParameters {
private void setCrossoverRate(final int crossover_rate) {
this.crossover_rate = crossover_rate;
int nbr_parents = Math.round(this.population_size
* (1 - crossover_rate / 100));
* (1 - (float)crossover_rate / 100));
if (nbr_parents % 2 == 1) {
nbr_parents++;
}
......
......@@ -8,10 +8,10 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.*;
class SolutionDistanceTest {
class SolutionDistanceTest {
@Test
void generateSolutionDistance() {
@org.junit.jupiter.api.Test
public void testGenerateSolutionDistance() {
SolutionDistance solution = new SolutionDistance(5, 1234);
double actualW = Utils.sumArrayElements(solution.getWeightsW());
double actualP = Utils.sumArrayElements(solution.getWeightsP());
......@@ -21,8 +21,8 @@ class SolutionDistanceTest {
assertEquals(5, solution.getWeightsP().length);
}
@Test
void computeScoreTo() {
@org.junit.jupiter.api.Test
public void computeScoreTo() {
SolutionDistance solution = new SolutionDistance(5,1234);
List<double[]> data = generateData(20, 5);
double[] expected = generateExpected(20);
......
package be.cylab.java.wowa.training;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.logging.Logger;
import static org.junit.jupiter.api.Assertions.*;
class TrainerTest {
private Trainer trainer;
@BeforeEach
void setUp() {
TrainerParameters parameters = new TrainerParameters(null, 100, 30, 10, TrainerParameters.SELECTION_METHOD_RWS, 100);
this.trainer = new Trainer(parameters);
}
@Test
void testRun() {
}
@Test
void testFindBestSolution() {
List<SolutionDistance> population = new ArrayList<>();
for (int i = 0; i < 100; i++) {
SolutionDistance solution = new SolutionDistance(5, i);
population.add(solution);
}
List<double[]> data = generateData(100, 5);
double[] expected = generateExpected(100);
List<SolutionDistance> computed_population = this.trainer.computeDistances(population, data, expected);
SolutionDistance bestSolution = this.trainer.findBestSolution(computed_population);
assertEquals(3.045715472368455, bestSolution.getDistance());
}
@Test
void testGenerateInitialPopulation() {
List<SolutionDistance> population = this.trainer.generateInitialPopulation(5, 100);
assertEquals(100, population.size());
for (SolutionDistance solution : population) {
assertEquals(5, solution.getWeightsW().length);
assertEquals(5, solution.getWeightsP().length);
assertEquals(Double.POSITIVE_INFINITY, solution.getDistance());
}
}
@Test
void testComputeDistances() {
}
@Test
void testSelectParents() {
List<SolutionDistance> population = this.trainer.generateInitialPopulation(5, 100);
List<double[]> data = generateData(100, 5);
double[] expected = generateExpected(100);
List<SolutionDistance> computed_population = this.trainer.computeDistances(population, data, expected);
SolutionDistance best_solution = this.trainer.findBestSolution(computed_population);
List<SolutionDistance> parents = this.trainer.selectParents(computed_population, 30, TrainerParameters.SELECTION_METHOD_TOS);
assertEquals(best_solution, parents.get(0));
assertEquals(30, parents.size());
computed_population = this.trainer.computeDistances(population, data, expected);
best_solution = this.trainer.findBestSolution(computed_population);
parents = this.trainer.selectParents(computed_population, 30, TrainerParameters.SELECTION_METHOD_RWS);
assertEquals(best_solution, parents.get(0));
assertEquals(30, parents.size());
}
@Test
void doReproduction() {
}
@Test
void reproduce() {
List<SolutionDistance> population = new ArrayList<>();
for (int i = 0; i < this.trainer.getParameters().getPopulationSize(); i++) {
SolutionDistance solution = new SolutionDistance(5, 2*i);
population.add(solution);
}
List<double[]> data = generateData(100, 5);
double[] expected = generateExpected(100);
population = this.trainer.computeDistances(population, data, expected);
List<SolutionDistance> parents = this.trainer.selectParents(population, 10, TrainerParameters.SELECTION_METHOD_RWS);
this.trainer.reproduce(parents.get(0), parents.get(1), parents, 1, 0.25875);
double[] expected_weight_w_child_1 = new double[] {0.1987717620825617, 0.154039554, 0.26261973461084054, 0.14551847852505223, 0.18295360349083772};
double[] expected_weight_w_child_2 = new double[] {0.28709019832051996, 0.133065638, 0.2628906777282328, 0.15925182144723837, 0.2137985176214255};
double[] expected_weight_p_child_1 = new double[] {0.3119444801528709, 0.130874978, 0.3091280121036283, 0.0943591649660173, 0.26790380888833576};
double[] expected_weight_p_child_2 = new double[] {0.23379455171152877, 0.107402174, 0.0550712506540036, 0.22911863245792205, 0.2604029799761541};
expected_weight_w_child_1 = Utils.normalizeWeights(expected_weight_w_child_1);
expected_weight_w_child_2 = Utils.normalizeWeights(expected_weight_w_child_2);
expected_weight_p_child_1 = Utils.normalizeWeights(expected_weight_p_child_1);
expected_weight_p_child_2 = Utils.normalizeWeights(expected_weight_p_child_2);
assertArrayEquals(parents.get(10).getWeightsW(), expected_weight_w_child_1, 0.0001);
assertArrayEquals(parents.get(10).getWeightsP(), expected_weight_p_child_1, 0.0001);
assertArrayEquals(parents.get(11).getWeightsW(), expected_weight_w_child_2, 0.0001);
assertArrayEquals(parents.get(11).getWeightsP(), expected_weight_p_child_2, 0.0001);
}
@Test
void randomlyMutateGenes() {
}
static List<double[]> generateData(final int size, final int weight_number) {
Random rnd = new Random(5489);
List<double[]> data = new ArrayList<>();
for (int i = 0; i < size; i++) {
double[] vector = new double[weight_number];
for (int j = 0; j < weight_number; j++) {
vector[j] = rnd.nextDouble();
}
data.add(vector);
}
return data;
}
static double[] generateExpected(final int size) {
Random rnd = new Random(5768);
double[] expected = new double[size];
for (int i = 0; i < size; i++) {
expected[i] = rnd.nextDouble();
}
return expected;
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment