Commit 8243849f authored by a.croix

Refactoring: move computeAUC and computeRocPoints from Utils to AbstractSolution. Update README

parent 9e2abf97
Pipeline #1722 passed
...@@ -19,7 +19,7 @@ Using maven : ...@@ -19,7 +19,7 @@ Using maven :
<dependency> <dependency>
<groupId>be.cylab</groupId> <groupId>be.cylab</groupId>
<artifactId>java-wowa-training</artifactId> <artifactId>java-wowa-training</artifactId>
<version>0.0.3</version> <version>0.0.4</version>
</dependency> </dependency>
``` ```
...@@ -42,21 +42,27 @@ public static void main(String[] args) { ...@@ -42,21 +42,27 @@ public static void main(String[] args) {
TrainerParameters parameters = new TrainerParameters(logger, population_size, TrainerParameters parameters = new TrainerParameters(logger, population_size,
crossover_rate, mutation_rate, max_generation, selection_method, generation_population_method); crossover_rate, mutation_rate, max_generation, selection_method, generation_population_method);
Trainer trainer = new Trainer(parameters);
//Input data //Input data
List<double[]> data = new ArrayList<double[]>(); List<List<Double>> data = new ArrayList<>();
data.add(new double[] {0.1, 0.2, 0.3, 0.4}); data.add(new ArrayList<>(Arrays.asList(0.1, 0.2, 0.3, 0.4)));
data.add(new double[] {0.1, 0.8, 0.3, 0.4}); data.add(new ArrayList<>(Arrays.asList(0.1, 0.8, 0.3, 0.4)));
data.add(new double[] {0.2, 0.6, 0.3, 0.4}); data.add(new ArrayList<>(Arrays.asList(0.2, 0.6, 0.3, 0.4)));
data.add(new double[] {0.1, 0.2, 0.5, 0.8}); data.add(new ArrayList<>(Arrays.asList(0.1, 0.2, 0.5, 0.8)));
data.add(new double[] {0.5, 0.1, 0.2, 0.3}); data.add(new ArrayList<>(Arrays.asList(0.5, 0.1, 0.2, 0.3)));
data.add(new double[] {0.1, 0.1, 0.1, 0.1}); data.add(new ArrayList<>(Arrays.asList(0.1, 0.1, 0.1, 0.1)));
//Expected aggregated value for each data vector //Expected aggregated value for each data vector
double[] expected = new double[] {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; List<Double> expected = new ArrayList<>(Arrays.asList(0.1, 0.2, 0.3, 0.4, 0.5, 0.6));
//Create object for the type of Solution (fitness score evaluation)
SolutionDistance solution_type = new SolutionDistance(data.get(0).size());
//Create trainer object
Trainer trainer = new Trainer(parameters, solution_type);
SolutionDistance solution = trainer.run(data, expected); AbstractSolution solution = trainer.run(data, expected);
//Display solution //Display solution
System.out.println(solution); System.out.println(solution);
...@@ -77,12 +83,27 @@ The **run** method returns a solution object, consisting of p weights and w weig ...@@ -77,12 +83,27 @@ The **run** method returns a solution object, consisting of p weights and w weig
### Parameters description ### Parameters description
- **population_size** : size of the population in the algorithm. Suggested value : 100 - **population_size** : size of the population in the algorithm. Suggested value : 100
- **crossover_rate** : defines the percentage of population generated by crossover. Must be between 1 and 100. Suggested value : 60 - **crossover_rate** : defines the percentage of population generated by crossover. Must be between 1 and 100. Suggested value : 60
- **mutation_rate** : defines the probability of a random element change in the population. Must be between 1 and 100. Suggested value : 15 - **mutation_rate** : defines the probability of a random element change in the population. Must be between 1 and 100. Suggested value : 15
- **selection_method** : Determines the method used to select elements in the population (to generate the next generation). SELECTION_METHOD_RWS for Roulette Wheel Selection and SELECTION_METHOD_TOS for Tournament Selection. - **selection_method** : Determines the method used to select elements in the population (to generate the next generation). SELECTION_METHOD_RWS for Roulette Wheel Selection and SELECTION_METHOD_TOS for Tournament Selection.
- **max_generation** : Determines the maximum number of iterations of the algorithm. - **max_generation** : Determines the maximum number of iterations of the algorithm.
- **generation_population_method**: Determines the method used to generate the initial population. POPULATION_INITIALIZATION_RANDOM for a full random initialization and POPULATION_INITIALIZATION_QUASI_RANDOM for a population with specific elements. - **generation_population_method**: Determines the method used to generate the initial population. POPULATION_INITIALIZATION_RANDOM for a full random initialization and POPULATION_INITIALIZATION_QUASI_RANDOM for a population with specific elements.
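The two selection methods named in the list above are standard genetic-algorithm techniques. The following is a minimal sketch of how they are usually defined, not the code used by this library: the class `SelectionSketch`, its method names, the tournament size `k` and the "smaller fitness wins" convention are all assumptions made for illustration.

```java
import java.util.Random;

// Hypothetical illustration class; it is not part of java-wowa-training.
final class SelectionSketch {

    // Roulette wheel selection: index i is drawn with probability
    // weights[i] / sum(weights). Weights must be non-negative.
    static int rouletteWheel(final double[] weights, final Random rng) {
        double total = 0.0;
        for (double w : weights) {
            if (w < 0.0) {
                throw new IllegalArgumentException("Negative weight");
            }
            total += w;
        }
        if (total <= 0.0) {
            throw new IllegalArgumentException("Weights sum to zero");
        }
        double target = rng.nextDouble() * total;
        double cumulative = 0.0;
        int last_positive = -1;
        for (int i = 0; i < weights.length; i++) {
            if (weights[i] > 0.0) {
                last_positive = i;
            }
            cumulative += weights[i];
            if (target < cumulative) {
                return i;
            }
        }
        // Only reached if rounding keeps target >= cumulative.
        return last_positive;
    }

    // Tournament selection: draw k random candidates (with replacement)
    // and keep the one with the smallest fitness score.
    static int tournament(final double[] fitness, final int k,
            final Random rng) {
        if (fitness.length == 0 || k < 1) {
            throw new IllegalArgumentException("Empty population or k < 1");
        }
        int best = rng.nextInt(fitness.length);
        for (int round = 1; round < k; round++) {
            int challenger = rng.nextInt(fitness.length);
            if (fitness[challenger] < fitness[best]) {
                best = challenger;
            }
        }
        return best;
    }

    public static void main(final String[] args) {
        Random rng = new Random(1);
        double[] fitness = {0.8, 0.2, 0.5};
        System.out.println(tournament(fitness, 2, rng));
        System.out.println(rouletteWheel(new double[] {1.0, 3.0, 1.0}, rng));
    }
}
```

Roulette wheel selection needs weights where larger means better, so a score that is minimised (such as a distance) would have to be transformed first. Tournament selection only compares scores and needs no such transformation.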
### Solution type
The algorithm is built to be used with different methods to evaluate the fitness score of each chromosome. Two different criteria are already implemented : *distance* and *AUC*.
- **Distance**: for each element in the population, the WOWA function is computed on all examples of the dataset. Then, for each example, the difference between the WOWA result just computed and the result given by the training dataset is computed. All these differences are added to obtain the distance, which is the fitness score of a chromosome. The smaller the distance, the better the chromosome.
- **AUC**: the Area Under the Curve (AUC) fitness score is designed for binary classification. To obtain the AUC, the Receiver Operating Characteristic (ROC) curve is built first. Concretely, the WOWA function is computed on all elements of the training dataset. Then, from these results, the ROC curve is built. The AUC of this ROC curve is the fitness score of an element. The larger the AUC, the better the chromosome.
It is possible to create a new Solution type with a new evaluation criterion. The new Solution type must inherit from the *AbstractSolution* class and override the method *computeScoreTo*. It is also necessary to modify the *createSolutionObject* method in the *Factory* class.
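The two fitness criteria described above can be illustrated with a small standalone sketch. It is not the library's implementation: the class `FitnessSketch` and its methods are hypothetical, the WOWA scores are assumed to be already computed, the distance is assumed to sum absolute differences (the README does not say whether differences are absolute or squared), and the AUC is computed with the pairwise-comparison formula instead of the `Roc` class used in the code.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical illustration class; it is not part of java-wowa-training.
final class FitnessSketch {

    // Distance criterion: sum of |score - expected| over all examples.
    // Assumption: absolute differences. Smaller is better.
    static double distance(final List<Double> scores,
            final List<Double> expected) {
        if (scores.size() != expected.size()) {
            throw new IllegalArgumentException(
                    "Scores and expected have different size");
        }
        double sum = 0.0;
        for (int i = 0; i < scores.size(); i++) {
            sum += Math.abs(scores.get(i) - expected.get(i));
        }
        return sum;
    }

    // AUC criterion: fraction of (positive, negative) pairs in which the
    // positive example has the higher score; ties count for one half.
    // This equals the area under the ROC curve. Larger is better.
    static double auc(final double[] scores, final boolean[] positive) {
        if (scores.length != positive.length) {
            throw new IllegalArgumentException(
                    "Scores and labels have different size");
        }
        double wins = 0.0;
        long pairs = 0;
        for (int i = 0; i < scores.length; i++) {
            if (!positive[i]) {
                continue;
            }
            for (int j = 0; j < scores.length; j++) {
                if (positive[j]) {
                    continue;
                }
                pairs++;
                if (scores[i] > scores[j]) {
                    wins += 1.0;
                } else if (scores[i] == scores[j]) {
                    wins += 0.5;
                }
            }
        }
        if (pairs == 0) {
            throw new IllegalArgumentException(
                    "Need at least one positive and one negative example");
        }
        return wins / pairs;
    }

    public static void main(final String[] args) {
        List<Double> scores = Arrays.asList(0.2, 0.5);
        List<Double> expected = Arrays.asList(0.1, 0.7);
        System.out.println(distance(scores, expected));
        System.out.println(auc(new double[] {0.9, 0.8, 0.3, 0.1},
                new boolean[] {true, false, true, false}));
    }
}
```

Because the trainer treats the smallest fitness score as the best, `SolutionAUC` in this commit stores the negated AUC as its `fitness_score`.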
## References ## References
- [The WOWA operator : a review (V. Torra)](https://gitlab.cylab.be/cylab/wowa-training/raw/c3c3785c767ab8258df0fc585aec1e8d463851cd/doc/Torra%20-%202011%20-%20The%20WOWA%20Operator%20A%20Review.1007_978-3.pdf) - [The WOWA operator : a review (V. Torra)](https://gitlab.cylab.be/cylab/wowa-training/raw/c3c3785c767ab8258df0fc585aec1e8d463851cd/doc/Torra%20-%202011%20-%20The%20WOWA%20Operator%20A%20Review.1007_978-3.pdf)
......
package be.cylab.java.wowa.training; package be.cylab.java.wowa.training;
import be.cylab.java.roc.Roc;
import be.cylab.java.roc.RocCoordinates;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import java.util.Date;
/** /**
* Abstract class Solution. * Abstract class Solution.
...@@ -91,6 +98,55 @@ public abstract class AbstractSolution ...@@ -91,6 +98,55 @@ public abstract class AbstractSolution
} }
/**
* @param data
* @param expected
* @return
*/
final List<RocCoordinates> computeRocPoints(
final List<List<Double>> data,
final List<Double> expected,
final boolean save_on_csv) {
if (data.size() != expected.size()) {
throw new IllegalArgumentException(
"Data and expected have different size");
}
boolean[] true_alert
= Utils.convertExpectedToBooleanArrayTrueAlert(expected);
double[] score = Utils.computeWOWAScoreWithData(this, data);
Roc roc = new Roc(score, true_alert);
List<RocCoordinates> coordinates = roc.computeRocPoints();
if (save_on_csv) {
DateFormat format = new SimpleDateFormat("yyyyMMdd_HH:mm:ss");
Date date = new Date();
be.cylab.java.roc.Utils.storeRocCoordinatesInCSVFile(coordinates,
"RocCoordinates_" + format.format(date) + ".csv");
}
return coordinates;
}
/**
* @param data
* @param expected
* @return
*/
final double computeAUC(
final List<List<Double>> data,
final List<Double> expected) {
if (data.size() != expected.size()) {
throw new IllegalArgumentException(
"Data and expected have different size");
}
//Convert the array with expected values (0 or 1) to a boolean array.
boolean[] true_alert
= Utils.convertExpectedToBooleanArrayTrueAlert(expected);
//For each elements in dataset, compute the wowa function.
double[] score = Utils.computeWOWAScoreWithData(this, data);
Roc roc = new Roc(score, true_alert);
return roc.computeAUC();
}
/** /**
* @param solution * @param solution
* @return int * @return int
......
...@@ -15,10 +15,14 @@ public class SolutionAUC extends AbstractSolution { ...@@ -15,10 +15,14 @@ public class SolutionAUC extends AbstractSolution {
super(weight_number); super(weight_number);
} }
SolutionAUC(final int weight_number, final int seed) {
super(weight_number, seed);
}
final void computeScoreTo( final void computeScoreTo(
final List<List<Double>> data, final List<List<Double>> data,
final List<Double> expected) { final List<Double> expected) {
this.fitness_score = 0.0; this.fitness_score = 0.0;
this.fitness_score = -(Utils.computeAUC(this, data, expected)); this.fitness_score = -(super.computeAUC(data, expected));
} }
} }
package be.cylab.java.wowa.training; package be.cylab.java.wowa.training;
import be.cylab.java.roc.Roc;
import be.cylab.java.roc.RocCoordinates;
import com.owlike.genson.GenericType; import com.owlike.genson.GenericType;
import com.owlike.genson.Genson; import com.owlike.genson.Genson;
import info.debatty.java.aggregation.WOWA; import info.debatty.java.aggregation.WOWA;
...@@ -11,10 +9,7 @@ import java.io.IOException; ...@@ -11,10 +9,7 @@ import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Date;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
...@@ -170,63 +165,14 @@ final class Utils { ...@@ -170,63 +165,14 @@ final class Utils {
return json; return json;
} }
/**
* @param solution
* @param data
* @param expected
* @return
*/
public static double computeAUC(
final AbstractSolution solution,
final List<List<Double>> data,
final List<Double> expected) {
if (data.size() != expected.size()) {
throw new IllegalArgumentException(
"Data and expected have different size");
}
//Convert the array with expected values (0 or 1) to a boolean array.
boolean[] true_alert = convertExpectedToBooleanArrayTrueAlert(expected);
//For each elements in dataset, compute the wowa function.
double[] score = computeWOWAScoreWithData(solution, data);
Roc roc = new Roc(score, true_alert);
return roc.computeAUC();
}
/**
* @param solution
* @param data
* @param expected
* @return
*/
public static List<RocCoordinates> computeRocPoints(
final AbstractSolution solution,
final List<List<Double>> data,
final List<Double> expected,
final boolean save_on_csv) {
if (data.size() != expected.size()) {
throw new IllegalArgumentException(
"Data and expected have different size");
}
boolean[] true_alert = convertExpectedToBooleanArrayTrueAlert(expected);
double[] score = computeWOWAScoreWithData(solution, data);
Roc roc = new Roc(score, true_alert);
List<RocCoordinates> coordinates = roc.computeRocPoints();
if (save_on_csv) {
DateFormat format = new SimpleDateFormat("yyyyMMdd_HH:mm:ss");
Date date = new Date();
be.cylab.java.roc.Utils.storeRocCoordinatesInCSVFile(coordinates,
"RocCoordinates_" + format.format(date) + ".csv");
}
return coordinates;
}
/** /**
* @param solution * @param solution
* @param data * @param data
* @return * @return
*/ */
private static double[] computeWOWAScoreWithData( static double[] computeWOWAScoreWithData(
final AbstractSolution solution, final AbstractSolution solution,
final List<List<Double>> data) { final List<List<Double>> data) {
double[] score = new double[data.size()]; double[] score = new double[data.size()];
...@@ -246,7 +192,7 @@ final class Utils { ...@@ -246,7 +192,7 @@ final class Utils {
* @param expected * @param expected
* @return * @return
*/ */
private static boolean[] convertExpectedToBooleanArrayTrueAlert( static boolean[] convertExpectedToBooleanArrayTrueAlert(
final List<Double> expected) { final List<Double> expected) {
boolean[] true_alert = new boolean[expected.size()]; boolean[] true_alert = new boolean[expected.size()];
for (int i = 0; i < expected.size(); i++) { for (int i = 0; i < expected.size(); i++) {
...@@ -268,7 +214,7 @@ final class Utils { ...@@ -268,7 +214,7 @@ final class Utils {
* @param elements * @param elements
* @return * @return
*/ */
public static double[] convertListDoubleToArrayDouble( static double[] convertListDoubleToArrayDouble(
final List<Double> elements) { final List<Double> elements) {
Double[] w = new Double[elements.size()]; Double[] w = new Double[elements.size()];
w = elements.toArray(w); w = elements.toArray(w);
...@@ -281,7 +227,7 @@ final class Utils { ...@@ -281,7 +227,7 @@ final class Utils {
* @param size * @param size
* @return * @return
*/ */
public static List<Double> initializeListWithZeroValues( static List<Double> initializeListWithZeroValues(
final int size) { final int size) {
List<Double> list = new ArrayList<>(); List<Double> list = new ArrayList<>();
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
......