Skip to content
Snippets Groups Projects
Commit 7afa3534 authored by a.croix's avatar a.croix
Browse files

Change distance to fitness_score

parent d4b89ac0
No related branches found
No related tags found
1 merge request!2Add AUC as a score for distance
Pipeline #1334 passed
...@@ -12,7 +12,7 @@ public abstract class AbstractSolution ...@@ -12,7 +12,7 @@ public abstract class AbstractSolution
protected double[] weights_w; protected double[] weights_w;
protected double[] weights_p; protected double[] weights_p;
protected double distance = Double.POSITIVE_INFINITY; protected double fitness_score = Double.POSITIVE_INFINITY;
/** /**
* SolutionDistance constructor. Needs weight number as parameter * SolutionDistance constructor. Needs weight number as parameter
...@@ -58,7 +58,7 @@ public abstract class AbstractSolution ...@@ -58,7 +58,7 @@ public abstract class AbstractSolution
return "SolutionDistance{" return "SolutionDistance{"
+ "weights_w=" + Arrays.toString(weights_w) + "weights_w=" + Arrays.toString(weights_w)
+ ", weights_p=" + Arrays.toString(weights_p) + ", weights_p=" + Arrays.toString(weights_p)
+ ", distance=" + Math.abs(distance) + ", fitness score=" + Math.abs(fitness_score)
+ '}'; + '}';
} }
...@@ -95,9 +95,9 @@ public abstract class AbstractSolution ...@@ -95,9 +95,9 @@ public abstract class AbstractSolution
*/ */
@Override @Override
public final int compareTo(final AbstractSolution solution) { public final int compareTo(final AbstractSolution solution) {
if (this.getDistance() > solution.getDistance()) { if (this.getFitnessScore() > solution.getFitnessScore()) {
return 1; return 1;
} else if (this.getDistance() < solution.getDistance()) { } else if (this.getFitnessScore() < solution.getFitnessScore()) {
return -1; return -1;
} else { } else {
return 0; return 0;
...@@ -117,14 +117,14 @@ public abstract class AbstractSolution ...@@ -117,14 +117,14 @@ public abstract class AbstractSolution
/** /**
* @return * @return
*/ */
public final double getDistance() { public final double getFitnessScore() {
return distance; return fitness_score;
} }
/** /**
* @param distance * @param distance
*/ */
public final void setDistance(final double distance) { public final void setFitnessScore(final double distance) {
return; return;
} }
......
...@@ -20,9 +20,9 @@ public class SolutionAUC extends AbstractSolution { ...@@ -20,9 +20,9 @@ public class SolutionAUC extends AbstractSolution {
final void computeScoreTo( final void computeScoreTo(
final List<double[]> data, final List<double[]> data,
final double[] expected) { final double[] expected) {
this.distance = 0; this.fitness_score = 0;
double[] aggregated_values = new double[data.size()]; double[] aggregated_values = new double[data.size()];
WOWA wowa = new WOWA(this.weights_w, this.weights_p); WOWA wowa = new WOWA(this.weights_w, this.weights_p);
this.distance = -(Utils.computeAUC(this, data, expected)); this.fitness_score = -(Utils.computeAUC(this, data, expected));
} }
} }
...@@ -43,15 +43,15 @@ public class SolutionDistance extends AbstractSolution { ...@@ -43,15 +43,15 @@ public class SolutionDistance extends AbstractSolution {
final double[] expected final double[] expected
) { ) {
this.distance = 0; this.fitness_score = 0;
for (int i = 0; i < data.size(); i++) { for (int i = 0; i < data.size(); i++) {
double[] vector = data.get(i); double[] vector = data.get(i);
double target_value = expected[i]; double target_value = expected[i];
WOWA wowa = new WOWA(this.weights_w, this.weights_p); WOWA wowa = new WOWA(this.weights_w, this.weights_p);
double aggregated_value = wowa.aggregate(vector); double aggregated_value = wowa.aggregate(vector);
this.distance += Math.pow(target_value - aggregated_value, 2); this.fitness_score += Math.pow(target_value - aggregated_value, 2);
} }
this.distance = Math.sqrt(this.distance); this.fitness_score = Math.sqrt(this.fitness_score);
} }
......
...@@ -58,8 +58,8 @@ public class Trainer { ...@@ -58,8 +58,8 @@ public class Trainer {
} }
//We found a new best solution //We found a new best solution
if (best_solution_of_current_population.getDistance() if (best_solution_of_current_population.getFitnessScore()
< best_solution.getDistance()) { < best_solution.getFitnessScore()) {
best_solution = best_solution_of_current_population; best_solution = best_solution_of_current_population;
} }
...@@ -77,7 +77,7 @@ public class Trainer { ...@@ -77,7 +77,7 @@ public class Trainer {
AbstractSolution best_solution = solutions.get(0); AbstractSolution best_solution = solutions.get(0);
for (AbstractSolution solution : solutions) { for (AbstractSolution solution : solutions) {
if (solution.getDistance() < best_solution.getDistance()) { if (solution.getFitnessScore() < best_solution.getFitnessScore()) {
best_solution = solution; best_solution = solution;
} }
} }
...@@ -192,7 +192,7 @@ public class Trainer { ...@@ -192,7 +192,7 @@ public class Trainer {
double normalized_distance = 0; double normalized_distance = 0;
for (AbstractSolution solution : solutions) { for (AbstractSolution solution : solutions) {
normalized_distance = normalized_distance normalized_distance = normalized_distance
+ (max + min - solution.getDistance()) / sum; + (max + min - solution.getFitnessScore()) / sum;
if (normalized_distance > tos) { if (normalized_distance > tos) {
selected_elements.add(solution); selected_elements.add(solution);
solutions.remove(solution); solutions.remove(solution);
...@@ -230,7 +230,7 @@ public class Trainer { ...@@ -230,7 +230,7 @@ public class Trainer {
//Compare the two selected element and put the solution with //Compare the two selected element and put the solution with
// the smallest distance in selected list // the smallest distance in selected list
if (solution1.getDistance() < solution2.getDistance()) { if (solution1.getFitnessScore() < solution2.getFitnessScore()) {
selected_elements.add(solution1); selected_elements.add(solution1);
solutions.remove(solution1); solutions.remove(solution1);
} else { } else {
......
...@@ -47,8 +47,8 @@ public final class Utils { ...@@ -47,8 +47,8 @@ public final class Utils {
final List<AbstractSolution> solutions) { final List<AbstractSolution> solutions) {
double max = Double.NEGATIVE_INFINITY; double max = Double.NEGATIVE_INFINITY;
for (AbstractSolution solution : solutions) { for (AbstractSolution solution : solutions) {
if (solution.getDistance() > max) { if (solution.getFitnessScore() > max) {
max = solution.getDistance(); max = solution.getFitnessScore();
} }
} }
return max; return max;
...@@ -63,8 +63,8 @@ public final class Utils { ...@@ -63,8 +63,8 @@ public final class Utils {
final List<AbstractSolution> solutions) { final List<AbstractSolution> solutions) {
double min = Double.POSITIVE_INFINITY; double min = Double.POSITIVE_INFINITY;
for (AbstractSolution solution : solutions) { for (AbstractSolution solution : solutions) {
if (solution.getDistance() < min) { if (solution.getFitnessScore() < min) {
min = solution.getDistance(); min = solution.getFitnessScore();
} }
} }
return min; return min;
...@@ -79,7 +79,7 @@ public final class Utils { ...@@ -79,7 +79,7 @@ public final class Utils {
final List<AbstractSolution> solutions) { final List<AbstractSolution> solutions) {
double sum = 0; double sum = 0;
for (AbstractSolution solution : solutions) { for (AbstractSolution solution : solutions) {
sum += solution.getDistance(); sum += solution.getFitnessScore();
} }
return sum; return sum;
} }
......
...@@ -27,7 +27,7 @@ class SolutionDistanceTest { ...@@ -27,7 +27,7 @@ class SolutionDistanceTest {
List<double[]> data = generateData(20, 5); List<double[]> data = generateData(20, 5);
double[] expected = generateExpected(20); double[] expected = generateExpected(20);
solution.computeScoreTo(data, expected); solution.computeScoreTo(data, expected);
assertEquals(1.5925246410672433, solution.getDistance()); assertEquals(1.5925246410672433, solution.getFitnessScore());
} }
@Test @Test
......
...@@ -36,7 +36,7 @@ class TrainerTest { ...@@ -36,7 +36,7 @@ class TrainerTest {
double[] expected = generateExpected(100); double[] expected = generateExpected(100);
List<AbstractSolution> computed_population = this.trainer.computeDistances(population, data, expected); List<AbstractSolution> computed_population = this.trainer.computeDistances(population, data, expected);
AbstractSolution bestSolution = this.trainer.findBestSolution(computed_population); AbstractSolution bestSolution = this.trainer.findBestSolution(computed_population);
assertEquals(3.0482902643223135, bestSolution.getDistance()); assertEquals(3.0482902643223135, bestSolution.getFitnessScore());
} }
@Test @Test
...@@ -46,7 +46,7 @@ class TrainerTest { ...@@ -46,7 +46,7 @@ class TrainerTest {
for (AbstractSolution solution : population) { for (AbstractSolution solution : population) {
assertEquals(5, solution.getWeightsW().length); assertEquals(5, solution.getWeightsW().length);
assertEquals(5, solution.getWeightsP().length); assertEquals(5, solution.getWeightsP().length);
assertEquals(Double.POSITIVE_INFINITY, solution.getDistance()); assertEquals(Double.POSITIVE_INFINITY, solution.getFitnessScore());
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment