From 4a2b5ef78f9ec7802a0aef2d7c9c504c87ea8499 Mon Sep 17 00:00:00 2001
From: "a.croix" <croix.alexandre@gmail.com>
Date: Mon, 18 Feb 2019 16:21:23 +0100
Subject: [PATCH] First version. Need tests

---
 .../be/cylab/java/wowa/training/Example.java  |   6 +
 .../java/wowa/training/SolutionDistance.java  |  16 +-
 .../be/cylab/java/wowa/training/Trainer.java  | 183 ++++++++++++++++--
 .../be/cylab/java/wowa/training/Utils.java    |   7 +-
 4 files changed, 191 insertions(+), 21 deletions(-)

diff --git a/src/main/java/be/cylab/java/wowa/training/Example.java b/src/main/java/be/cylab/java/wowa/training/Example.java
index a7523b2..f2351b8 100644
--- a/src/main/java/be/cylab/java/wowa/training/Example.java
+++ b/src/main/java/be/cylab/java/wowa/training/Example.java
@@ -2,6 +2,7 @@
 package be.cylab.java.wowa.training;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.logging.Logger;
 
@@ -29,6 +30,11 @@ public class Example {
         for (SolutionDistance solution : population) {
             System.out.println(solution);
         }
+    System.out.println("Trié\n");
+        Collections.sort(population);
+        for (SolutionDistance solution : population) {
+            System.out.println(solution);
+        }
 
 
     }
diff --git a/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java b/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java
index ac03857..3b375c3 100644
--- a/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java
+++ b/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java
@@ -3,9 +3,10 @@ package be.cylab.java.wowa.training;
 //import info.debatty.java.aggregation.WOWA;
 
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 
-public class SolutionDistance {
+public class SolutionDistance implements Comparable<SolutionDistance> {
 
     public double[] weights_w;
     public double[] weights_p;
@@ -53,14 +54,17 @@ public class SolutionDistance {
 
     }
 
+    @Override
+    public int compareTo(SolutionDistance solution) {
+        return this.getDistance() > solution.getDistance() ? -1 : this.getDistance() < solution.getDistance() ? 1 : 0;
+    }
+
+
     public void normalize() {
         this.weights_w = Utils.normalizeWeights(this.weights_w);
         this.weights_p = Utils.normalizeWeights(this.weights_p);
     }
 
-    public static void sort(List<SolutionDistance> solutions) {
-
-    }
 
     public double[] getWeights_w() {
         return weights_w;
@@ -73,4 +77,8 @@ public class SolutionDistance {
     public double getDistance() {
         return distance;
     }
+
+    public void setDistance(double distance) {
+        this.distance = distance;
+    }
 }
diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java
index cf287dd..2e7e2dd 100644
--- a/src/main/java/be/cylab/java/wowa/training/Trainer.java
+++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java
@@ -1,6 +1,7 @@
 package be.cylab.java.wowa.training;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 
 public class Trainer {
@@ -10,15 +11,43 @@ public class Trainer {
     public Trainer(TrainerParameters parameters) {
         this.parameters = parameters;
     }
-/*
+
     public SolutionDistance run(List<double[]> data, double[] expected) {
+        List<SolutionDistance> current_population = this.generateInitialPopulationAndComputeDistances(this.getParameters().getPopulationSize(), data, expected);
+
+        Collections.sort(current_population);
+        SolutionDistance best_solution = this.findBestSolution(current_population);
+        for (int generation = 0; generation < this.parameters.getMaxGenerationNumber(); generation++) {
+            current_population = this.performReproduction(current_population, data, expected);
+            SolutionDistance best_solution_of_current_population = this.findBestSolution(current_population);
+
+            if (this.getParameters().getLogger() != null) {
+
+            }
+
+            //We found a new best solution
+            if (best_solution_of_current_population.getDistance() < best_solution.getDistance()) {
+                best_solution = best_solution_of_current_population;
+            }
 
+            if (Math.abs(best_solution.getDistance()) < this.getParameters().getTriggerDistance()) {
+                break;
+            }
+        }
+        return best_solution;
     }
 
-    public SolutionDistance findBestSolution(SolutionDistance[] solutions) {
+    public SolutionDistance findBestSolution(List<SolutionDistance> solutions) {
+        SolutionDistance best_solution = new SolutionDistance(solutions.get(0).weights_p.length);
 
+        for (SolutionDistance solution : solutions) {
+            if (solution.getDistance() < best_solution.getDistance()) {
+                best_solution = solution;
+            }
+        }
+        return best_solution;
     }
-*/
+
     public List<SolutionDistance> generateInitialPopulation(int numberOfWeights, int populationSize) {
         List<SolutionDistance> population = new ArrayList<SolutionDistance>();
         for (int i = 0; i < populationSize; i++) {
@@ -33,9 +62,6 @@ public class Trainer {
         return population;
     }
 
-    public double getRandomDouble() {
-        return 0;
-    }
 
     public List<SolutionDistance> computeDistances(List<SolutionDistance> solutions, List<double[]> data, double[] expected) {
         for (SolutionDistance solution : solutions) {
@@ -44,31 +70,156 @@ public class Trainer {
         return solutions;
 
     }
-/*
-    private List<SolutionDistance> rouletteWheelSelection(List<SolutionDistance> solutions, int count) {
 
+    private List<SolutionDistance> rouletteWheelSelection(List<SolutionDistance> solutions, List<SolutionDistance> selected_elements, int count) {
+        double max = Utils.findMaxDistance(solutions);
+        double min = Utils.findMinDistance(solutions);
+        double sum = Utils.sumTotalDistance(solutions);
+
+
+        if (solutions.size() < count) {
+            throw new IllegalArgumentException("Not enough elements in population to select "+ count + "parents");
+        }
+        while (selected_elements.size() < count) {
+            double tos = Math.random();
+            double normalized_distance = 0;
+            for (SolutionDistance solution : solutions) {
+                normalized_distance = normalized_distance + (max + min - solution.getDistance()) / sum;
+                if (normalized_distance > tos) {
+                    selected_elements.add(solution);
+                    solutions.remove(solution);
+                    break;
+                }
+            }
+        }
+        return selected_elements;
     }
 
-    private List<SolutionDistance> tournamentSelection(List<SolutionDistance> solutions, int count){
+    private List<SolutionDistance> tournamentSelection(List<SolutionDistance> solutions, List<SolutionDistance> selected_elements, int count){
+
+        while (selected_elements.size() < count) {
+            Collections.sort(solutions);
+
+            //Select random element in the population
+            int solution1_index = Utils.randomInteger(0, solutions.size());
+            SolutionDistance solution1 = solutions.get(solution1_index);
 
+            //Select random element in the population
+            int solution2_index = Utils.randomInteger(0, solutions.size());
+            SolutionDistance solution2 = solutions.get(solution2_index);
+
+            //Compare the two selected element and put the solution with the smallest distance in selected list
+            if (solution1.getDistance() < solution2.getDistance()) {
+                selected_elements.add(solution1);
+                solutions.remove(solution1);
+            } else {
+                selected_elements.add(solution2);
+                solutions.remove(solution2);
+
+            }
+        }
+        return selected_elements;
     }
 
     public List<SolutionDistance> selectParents(List<SolutionDistance> solutions, int count, int selectionMethod) {
+        List<SolutionDistance> selected_parents = new ArrayList<>();
+        //Select the two best current solutions
+
+        //Sort population
+        Collections.sort(solutions);
+        //Put the two best elements in selected parents list
+        selected_parents.add(solutions.get(0));
+        selected_parents.add(solutions.get(1));
+        //Remove the two best elements from the list solutions
+        solutions.remove(0);
+        solutions.remove(1);
+
+        if(selectionMethod == TrainerParameters.SELECTION_METHOD_RWS) {
+
+            return this.rouletteWheelSelection(solutions, selected_parents, count - 2);
+        }
+        else if (selectionMethod == TrainerParameters.SELECTION_METHOD_TOS) {
+            return this.tournamentSelection(solutions, selected_parents, count - 2);
+        }
+        throw new IllegalArgumentException("Invalid selection method");
 
     }
 
     public List<SolutionDistance> doReproduction(List<SolutionDistance> solutions) {
+        int nbr_weights = solutions.get(0).weights_p.length;
+
+        //we add children to the current list of solutions
+        while (solutions.size() < this.parameters.getPopulationSize()) {
+            int dad = Utils.randomInteger(0, solutions.size() - 1);
+            int mom;
+            do {
+                mom = Utils.randomInteger(0, solutions.size() - 1);
+            } while (dad == mom);
 
+            int cut_position = Utils.randomInteger(0, nbr_weights - 1);
+
+            //beta is used to compute the new value at the cut position
+            double beta = Math.random();
+            solutions = this.reproduce(solutions.get(dad), solutions.get(mom), solutions, cut_position, beta);
+        }
+        while(solutions.size() > this.getParameters().getPopulationSize()) {
+            solutions.remove(solutions.get(solutions.size() - 1));
+        }
+        return solutions;
     }
 
-    public List<SolutionDistance> reproduce(SolutionDistance dad, SolutionDistance mom, int cutPosition, float beta) {
-        //Different compare to PHP implementation !! PHP allows to return several elements !
+    public List<SolutionDistance> reproduce(SolutionDistance dad, SolutionDistance mom, List<SolutionDistance> solutions, int cutPosition, double beta) {
+        double pnew1W = dad.weights_w[cutPosition] - beta * (dad.weights_w[cutPosition] - mom.weights_w[cutPosition]);
+        double pnew2W = mom.weights_w[cutPosition] + beta * (dad.weights_w[cutPosition] - mom.weights_w[cutPosition]);
+
+        double pnew1P = dad.weights_p[cutPosition] - beta * (dad.weights_p[cutPosition] - mom.weights_p[cutPosition]);
+        double pnew2P = mom.weights_p[cutPosition] + beta * (dad.weights_p[cutPosition] - mom.weights_p[cutPosition]);
+
+        SolutionDistance child1 = new SolutionDistance(dad.weights_p.length);
+        SolutionDistance child2 = new SolutionDistance(dad.weights_p.length);
+
+        for (int i = 0; i < cutPosition; i++) {
+            child1.weights_w[i] = dad.weights_w[i];
+            child1.weights_p[i] = dad.weights_p[i];
+
+            child2.weights_w[i] = mom.weights_w[i];
+            child2.weights_p[i] = mom.weights_p[i];
+        }
+
+        child1.weights_w[cutPosition] = pnew1W;
+        child2.weights_w[cutPosition] = pnew2W;
+        child1.weights_p[cutPosition] = pnew1P;
+        child2.weights_p[cutPosition] = pnew2P;
+
+        int nbre_weights = dad.weights_w.length;
+        for (int i = cutPosition + 1; i < nbre_weights; i++) {
+            child1.weights_w[i] = mom.weights_w[i];
+            child1.weights_p[i] = mom.weights_p[i];
+
+            child2.weights_w[i] = dad.weights_w[i];
+            child2.weights_p[i] = dad.weights_p[i];
+        }
+        //Check and correct only if we used a QuasiRandom generation
+        child1.normalize();
+        child2.normalize();
+        solutions.add(child1);
+        solutions.add(child2);
+
+        return solutions;
     }
 
     public List<SolutionDistance> randomlyMutateGenes(List<SolutionDistance> solutions) {
+        double probability = this.getParameters().getMutationRate() / 100;
 
+        for (int i = 0; i < solutions.size(); i++) {
+            if (i > 1) {
+                solutions.get(i).randomlyMutateWithProbability(probability);
+            }
+
+        }
+        return solutions;
     }
-*/
+
     public TrainerParameters getParameters() {
         return this.parameters;
     }
@@ -78,11 +229,15 @@ public class Trainer {
         List<SolutionDistance> initial_population = this.generateInitialPopulation(number_of_weights, populationSize);
         return this.computeDistances(initial_population, data, expected);
     }
-/*
+
     public List<SolutionDistance> performReproduction(List<SolutionDistance> population, List<double[]> data, double[] expected) {
+        List<SolutionDistance> parents = this.selectParents(population, this.getParameters().getNumberParents(), this.getParameters().getSelectionMethod());
+        List<SolutionDistance> new_generation = this.doReproduction(parents);
+        List<SolutionDistance> mutated = this.randomlyMutateGenes(new_generation);
 
+        return this.computeDistances(mutated, data, expected);
     }
-*/
+
     private void checkAndCorrectNullWeightVector(SolutionDistance child) {
 
     }
diff --git a/src/main/java/be/cylab/java/wowa/training/Utils.java b/src/main/java/be/cylab/java/wowa/training/Utils.java
index 0f31e02..680f633 100644
--- a/src/main/java/be/cylab/java/wowa/training/Utils.java
+++ b/src/main/java/be/cylab/java/wowa/training/Utils.java
@@ -2,6 +2,7 @@ package be.cylab.java.wowa.training;
 
 import java.lang.reflect.Array;
 import java.util.Arrays;
+import java.util.List;
 import java.util.Random;
 
 public class Utils {
@@ -15,7 +16,7 @@ public class Utils {
         return weightsNormalized;
     }
 
-    public double findMaxDistance(SolutionDistance[] solutions) {
+    public static double findMaxDistance(List<SolutionDistance> solutions) {
         double max = Double.NEGATIVE_INFINITY;
         for (SolutionDistance solution : solutions) {
             if(solution.getDistance() > max) {
@@ -25,7 +26,7 @@ public class Utils {
         return max;
     }
 
-    public double findMinDistance(SolutionDistance[] solutions) {
+    public static double findMinDistance(List<SolutionDistance> solutions) {
         double min = Double.POSITIVE_INFINITY;
         for (SolutionDistance solution : solutions) {
             if (solution.getDistance() < min){
@@ -35,7 +36,7 @@ public class Utils {
         return min;
     }
 
-    public double sumTotalDistance(SolutionDistance[] solutions) {
+    public static double sumTotalDistance(List<SolutionDistance> solutions) {
         double sum = 0;
         for (SolutionDistance solution : solutions) {
             sum += solution.getDistance();
-- 
GitLab