Skip to content
Snippets Groups Projects
Commit 95898aec authored by a.croix's avatar a.croix
Browse files

Change weights from array to List

parent f8bdb290
No related branches found
No related tags found
No related merge requests found
Pipeline #1399 failed
package be.cylab.java.wowa.training;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
......@@ -10,8 +11,8 @@ import java.util.Random;
public abstract class AbstractSolution
implements Comparable<AbstractSolution>, Cloneable {
protected double[] weights_w;
protected double[] weights_p;
protected List<Double> weights_w;
protected List<Double> weights_p;
protected double fitness_score = Double.POSITIVE_INFINITY;
/**
......@@ -20,11 +21,11 @@ public abstract class AbstractSolution
* @param weights_number
*/
public AbstractSolution(final int weights_number) {
this.weights_w = new double[weights_number];
this.weights_p = new double[weights_number];
this.weights_w = new ArrayList<>();
this.weights_p = new ArrayList<>();
for (int i = 0; i < weights_number; i++) {
this.weights_w[i] = Math.random();
this.weights_p[i] = Math.random();
this.weights_w.add(Math.random());
this.weights_p.add(Math.random());
}
this.normalize();
}
......@@ -39,11 +40,11 @@ public abstract class AbstractSolution
public AbstractSolution(final int weights_number, final int seed) {
Random rnd = new Random(seed);
this.weights_w = new double[weights_number];
this.weights_p = new double[weights_number];
this.weights_w = new ArrayList<>();
this.weights_p = new ArrayList<>();
for (int i = 0; i < weights_number; i++) {
this.weights_w[i] = rnd.nextDouble();
this.weights_p[i] = rnd.nextDouble();
this.weights_w.add(rnd.nextDouble());
this.weights_p.add(rnd.nextDouble());
}
this.normalize();
}
......@@ -56,8 +57,8 @@ public abstract class AbstractSolution
@Override
public final String toString() {
return "SolutionDistance{"
+ "weights_w=" + Arrays.toString(weights_w)
+ ", weights_p=" + Arrays.toString(weights_p)
+ "weights_w=" + Arrays.toString(weights_w.toArray())
+ ", weights_p=" + Arrays.toString(weights_p.toArray())
+ ", fitness score=" + Math.abs(fitness_score)
+ '}';
}
......@@ -77,13 +78,13 @@ public abstract class AbstractSolution
//new weight value and position
double new_weight_value = Math.random();
int random_index = Utils.randomInteger(0, this.weights_p.length - 1);
int random_index = Utils.randomInteger(0, this.weights_p.size() - 1);
//Select w or p weights
int weight_selection = Utils.randomInteger(0, 1);
if (weight_selection == 0) {
this.weights_w[random_index] = new_weight_value;
this.weights_w.set(random_index, new_weight_value);
} else {
this.weights_p[random_index] = new_weight_value;
this.weights_p.set(random_index, new_weight_value);
}
this.normalize();
......@@ -133,28 +134,28 @@ public abstract class AbstractSolution
/**
* @return
*/
public final double[] getWeightsW() {
public final List<Double> getWeightsW() {
return this.weights_w;
}
/**
* @return
*/
public final double[] getWeightsP() {
public final List<Double> getWeightsP() {
return this.weights_p;
}
/**
* @param weights_w
*/
final void setWeightsW(final double[] weights_w) {
final void setWeightsW(final List<Double> weights_w) {
this.weights_w = weights_w;
}
/**
* @param weights_p
*/
final void setWeightsP(final double[] weights_p) {
final void setWeightsP(final List<Double> weights_p) {
this.weights_p = weights_p;
}
......
......@@ -49,7 +49,9 @@ public class SolutionDistance extends AbstractSolution {
for (int i = 0; i < data.size(); i++) {
double[] vector = data.get(i);
double target_value = expected[i];
WOWA wowa = new WOWA(this.weights_w, this.weights_p);
WOWA wowa = new WOWA(
Utils.convertListDoubleToArrayDouble(this.weights_w),
Utils.convertListDoubleToArrayDouble(this.weights_p));
double aggregated_value = wowa.aggregate(vector);
this.fitness_score += Math.pow(target_value - aggregated_value, 2);
}
......
package be.cylab.java.wowa.training;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.logging.Level;
......@@ -130,12 +131,16 @@ public class Trainer {
int position_w = Utils.randomInteger(0, number_of_weight - 1);
int position_p = Utils.randomInteger(0, number_of_weight - 1);
if (random_memory[position_w][position_p] == 0) {
AbstractSolution solution = new SolutionDistance(5);
double[] weights_w = new double[number_of_weight];
weights_w[position_w] = 1;
double[] weights_p = new double[number_of_weight];
solution.setWeightsW(weights_w);
solution.setWeightsP(weights_p);
AbstractSolution solution
= this.factory.createSolutionObject(number_of_weight);
Double[] weights_w = new Double[number_of_weight];
weights_w[position_w] = 1.0;
Double[] weights_p = new Double[number_of_weight];
weights_p[position_p] = 1.0;
List<Double> w = Arrays.asList(weights_w);
List<Double> p = Arrays.asList(weights_p);
solution.setWeightsW(w);
solution.setWeightsP(p);
population.add(solution);
random_memory[position_w][position_p] = 1;
}
......@@ -294,7 +299,7 @@ public class Trainer {
final List<AbstractSolution> solutions) {
int original_solution_size = solutions.size();
int nbr_weights = solutions.get(0).getWeightsP().length;
int nbr_weights = solutions.get(0).getWeightsP().size();
//we add children to the current list of solutions
while (solutions.size() < this.parameters.getPopulationSize()) {
......@@ -339,45 +344,45 @@ public class Trainer {
final int cut_position,
final double beta) {
double p_new1_w = dad.getWeightsW()[cut_position] - beta
* (dad.getWeightsW()[cut_position]
- mom.getWeightsW()[cut_position]);
double p_new2_w = mom.getWeightsW()[cut_position] + beta
* (dad.getWeightsW()[cut_position]
- mom.getWeightsW()[cut_position]);
double p_new1_w = dad.getWeightsW().get(cut_position) - beta
* (dad.getWeightsW().get(cut_position)
- mom.getWeightsW().get(cut_position));
double p_new2_w = mom.getWeightsW().get(cut_position) + beta
* (dad.getWeightsW().get(cut_position)
- mom.getWeightsW().get(cut_position));
double p_new1_p = dad.getWeightsP()[cut_position] - beta
* (dad.getWeightsP()[cut_position]
- mom.getWeightsP()[cut_position]);
double p_new2_p = mom.getWeightsP()[cut_position] + beta
* (dad.getWeightsP()[cut_position]
- mom.getWeightsP()[cut_position]);
double p_new1_p = dad.getWeightsP().get(cut_position) - beta
* (dad.getWeightsP().get(cut_position)
- mom.getWeightsP().get(cut_position));
double p_new2_p = mom.getWeightsP().get(cut_position) + beta
* (dad.getWeightsP().get(cut_position)
- mom.getWeightsP().get(cut_position));
AbstractSolution child1
= this.factory.createSolutionObject(dad.weights_p.length);
= this.factory.createSolutionObject(dad.weights_p.size());
AbstractSolution child2
= this.factory.createSolutionObject(dad.weights_p.length);
= this.factory.createSolutionObject(dad.weights_p.size());
for (int i = 0; i < cut_position; i++) {
child1.getWeightsW()[i] = dad.getWeightsW()[i];
child1.getWeightsP()[i] = dad.getWeightsP()[i];
child1.getWeightsW().set(i, dad.getWeightsW().get(i));
child1.getWeightsP().set(i, dad.getWeightsP().get(i));
child2.getWeightsW()[i] = mom.getWeightsW()[i];
child2.getWeightsP()[i] = mom.getWeightsP()[i];
child2.getWeightsW().set(i, mom.getWeightsW().get(i));
child2.getWeightsP().set(i, mom.getWeightsP().get(i));
}
child1.getWeightsW()[cut_position] = p_new1_w;
child2.getWeightsW()[cut_position] = p_new2_w;
child1.getWeightsP()[cut_position] = p_new1_p;
child2.getWeightsP()[cut_position] = p_new2_p;
child1.getWeightsW().set(cut_position, p_new1_w);
child2.getWeightsW().set(cut_position, p_new2_w);
child1.getWeightsP().set(cut_position, p_new1_p);
child2.getWeightsP().set(cut_position, p_new2_p);
int nbre_weights = dad.getWeightsW().length;
int nbre_weights = dad.getWeightsW().size();
for (int i = cut_position + 1; i < nbre_weights; i++) {
child1.getWeightsW()[i] = mom.getWeightsW()[i];
child1.getWeightsP()[i] = mom.getWeightsP()[i];
child1.getWeightsW().set(i, mom.getWeightsW().get(i));
child1.getWeightsP().set(i, mom.getWeightsP().get(i));
child2.getWeightsW()[i] = dad.getWeightsW()[i];
child2.getWeightsP()[i] = dad.getWeightsP()[i];
child2.getWeightsW().set(i, dad.getWeightsW().get(i));
child2.getWeightsP().set(i, dad.getWeightsP().get(i));
}
//Check and correct only if we used a QuasiRandom generation
if (this.getParameters().getPopulationInitializationMethod()
......@@ -475,16 +480,18 @@ public class Trainer {
* @param child AbstractSolution
*/
private void checkAndCorrectNullWeightVector(final AbstractSolution child) {
if (Utils.sumArrayElements(child.getWeightsW()) == 0) {
int index = Utils.randomInteger(0, child.getWeightsW().length) - 1;
double[] weights_w = new double[5];
weights_w[index] = 1;
if (Utils.sumListElements(child.getWeightsW()) == 0) {
int index = Utils.randomInteger(0, child.getWeightsW().size()) - 1;
Double[] w = new Double[child.getWeightsP().size()];
w[index] = 1.0;
List<Double> weights_w = Arrays.asList(w);
child.setWeightsW(weights_w);
}
if (Utils.sumArrayElements(child.getWeightsP()) == 0) {
int index = Utils.randomInteger(0, child.getWeightsP().length) - 1;
double[] weights_p = new double[5];
weights_p[index] = 1;
if (Utils.sumListElements(child.getWeightsP()) == 0) {
int index = Utils.randomInteger(0, child.getWeightsP().size()) - 1;
Double[] p = new Double[child.getWeightsP().size()];
p[index] = 1.0;
List<Double> weights_p = Arrays.asList(p);
child.setWeightsP(weights_p);
}
}
......
......@@ -5,12 +5,14 @@ import be.cylab.java.roc.RocCoordinates;
import com.owlike.genson.GenericType;
import com.owlike.genson.Genson;
import info.debatty.java.aggregation.WOWA;
import org.apache.commons.lang3.ArrayUtils;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Random;
......@@ -26,20 +28,22 @@ public final class Utils {
/**
* Normalize a weights vector.
*
* @param weights
* @return
*/
public static double[] normalizeWeights(final double[] weights) {
double sum_weight = Utils.sumArrayElements(weights);
double[] weights_normalized = new double[weights.length];
for (int i = 0; i < weights.length; i++) {
weights_normalized[i] = weights[i] / sum_weight;
public static List<Double> normalizeWeights(final List<Double> weights) {
double sum_weight = Utils.sumListElements(weights);
List<Double> weights_normalized = new ArrayList<>();
for (int i = 0; i < weights.size(); i++) {
weights_normalized.add(weights.get(i) / sum_weight);
}
return weights_normalized;
}
/**
* Find the max distance in a List of SolutionDistance.
*
* @param solutions
* @return
*/
......@@ -56,6 +60,7 @@ public final class Utils {
/**
* Find the minimum distance in a List of SolutionDistance.
*
* @param solutions
* @return
*/
......@@ -72,6 +77,7 @@ public final class Utils {
/**
* Compute the sum of all distance in a List of SolutionDistance.
*
* @param solutions
* @return
*/
......@@ -86,12 +92,13 @@ public final class Utils {
/**
* Compute the sum of all array's element.
*
* @param array
* @return
*/
public static double sumArrayElements(final double[] array) {
public static double sumListElements(final List<Double> array) {
float sum = 0;
for (double weight : array) {
for (Double weight : array) {
sum += weight;
}
return sum;
......@@ -99,6 +106,7 @@ public final class Utils {
/**
* Generate a random integer in a range.
*
* @param min
* @param max
* @return
......@@ -113,6 +121,7 @@ public final class Utils {
/**
* Import a json file.
*
* @param filename
* @return
* @throws IOException if we cannot read file
......@@ -125,12 +134,14 @@ public final class Utils {
String data_json = new String(Files.readAllBytes(Paths.get(filename)));
List<double[]> data = genson.deserialize(
data_json,
new GenericType<List<double[]>>() { });
new GenericType<List<double[]>>() {
});
return data;
}
/**
* Import result file.
*
* @param filename
* @return
* @throws IOException if we cannot read file
......@@ -202,7 +213,11 @@ public final class Utils {
final AbstractSolution solution,
final List<double[]> data) {
double[] score = new double[data.size()];
WOWA wowa = new WOWA(solution.getWeightsW(), solution.getWeightsP());
double[] weights_w
= convertListDoubleToArrayDouble(solution.getWeightsW());
double[] weights_p
= convertListDoubleToArrayDouble(solution.getWeightsP());
WOWA wowa = new WOWA(weights_w, weights_p);
for (int i = 0; i < data.size(); i++) {
score[i] = wowa.aggregate(data.get(i));
}
......@@ -229,4 +244,12 @@ public final class Utils {
return true_alert;
}
public static double[] convertListDoubleToArrayDouble(
final List<Double> elements) {
Double[] w = new Double[elements.size()];
w = elements.toArray(w);
double[] ww = ArrayUtils.toPrimitive(w);
return ww;
}
}
......@@ -13,12 +13,12 @@ class SolutionDistanceTest {
@org.junit.jupiter.api.Test
public void testGenerateSolutionDistance() {
SolutionDistance solution = new SolutionDistance(5, 1234);
double actualW = Utils.sumArrayElements(solution.getWeightsW());
double actualP = Utils.sumArrayElements(solution.getWeightsP());
double actualW = Utils.sumListElements(solution.getWeightsW());
double actualP = Utils.sumListElements(solution.getWeightsP());
assertEquals(1, actualW, 0.00001);
assertEquals(1, actualP, 0.00001);
assertEquals(5, solution.getWeightsW().length);
assertEquals(5, solution.getWeightsP().length);
assertEquals(5, solution.getWeightsW().size());
assertEquals(5, solution.getWeightsP().size());
}
@org.junit.jupiter.api.Test
......@@ -35,8 +35,8 @@ class SolutionDistanceTest {
SolutionDistance solution = new SolutionDistance(5, 2548);
SolutionDistance copy = solution.clone();
solution.randomlyMutateWithProbability(1);
double actualW = Utils.sumArrayElements(solution.getWeightsW());
double actualP = Utils.sumArrayElements(solution.getWeightsP());
double actualW = Utils.sumListElements(solution.getWeightsW());
double actualP = Utils.sumListElements(solution.getWeightsP());
assertEquals(1, actualW, 0.00001);
assertEquals(1, actualP, 0.00001);
assertFalse(copy.equals(solution));
......
......@@ -3,10 +3,7 @@ package be.cylab.java.wowa.training;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.*;
import java.util.logging.Logger;
import static org.junit.jupiter.api.Assertions.*;
......@@ -44,8 +41,8 @@ class TrainerTest {
List<AbstractSolution> population = this.trainer.generateInitialPopulation(5, 100);
assertEquals(100, population.size());
for (AbstractSolution solution : population) {
assertEquals(5, solution.getWeightsW().length);
assertEquals(5, solution.getWeightsP().length);
assertEquals(5, solution.getWeightsW().size());
assertEquals(5, solution.getWeightsP().size());
assertEquals(Double.POSITIVE_INFINITY, solution.getFitnessScore());
}
}
......@@ -90,18 +87,27 @@ class TrainerTest {
population = this.trainer.computeDistances(population, data, expected);
List<AbstractSolution> parents = this.trainer.selectParents(population, 10, TrainerParameters.SELECTION_METHOD_RWS);
this.trainer.reproduce(parents.get(0), parents.get(1), parents, 1, 0.25875);
double[] expected_weight_w_child_1 = new double[] {0.29251897956852974, 0.177172118, 0.053456076807862664, 0.35802487078331857, 0.15845805675613056};
double[] expected_weight_w_child_2 = new double[] {0.28029560126515357, 0.159332272, 0.08953411511177484, 0.17126082447600055, 0.25994701067132736};
double[] expected_weight_p_child_1 = new double[] {0.16155217097863003, 0.215065784, 0.2817638841523853, 0.2017962998473193, 0.163708146070448};
double[] expected_weight_p_child_2 = new double[] {0.15632297412007345, 0.202921414, 0.20635412000064338, 0.2030854137628754, 0.20742982929073198};
List<Double> expected_weight_w_child_1 = Arrays.asList(new Double[]{0.29251897956852974, 0.177172118, 0.053456076807862664, 0.35802487078331857, 0.15845805675613056});
List<Double> expected_weight_w_child_2 = Arrays.asList(new Double[] {0.28029560126515357, 0.159332272, 0.08953411511177484, 0.17126082447600055, 0.25994701067132736});
List<Double> expected_weight_p_child_1 = Arrays.asList(new Double[] {0.16155217097863003, 0.215065784, 0.2817638841523853, 0.2017962998473193, 0.163708146070448});
List<Double> expected_weight_p_child_2 = Arrays.asList(new Double[] {0.15632297412007345, 0.202921414, 0.20635412000064338, 0.2030854137628754, 0.20742982929073198});
expected_weight_w_child_1 = Utils.normalizeWeights(expected_weight_w_child_1);
expected_weight_w_child_2 = Utils.normalizeWeights(expected_weight_w_child_2);
expected_weight_p_child_1 = Utils.normalizeWeights(expected_weight_p_child_1);
expected_weight_p_child_2 = Utils.normalizeWeights(expected_weight_p_child_2);
assertArrayEquals(parents.get(10).getWeightsW(), expected_weight_w_child_1, 0.0001);
assertArrayEquals(parents.get(10).getWeightsP(), expected_weight_p_child_1, 0.0001);
assertArrayEquals(parents.get(11).getWeightsW(), expected_weight_w_child_2, 0.0001);
assertArrayEquals(parents.get(11).getWeightsP(), expected_weight_p_child_2, 0.0001);
for (int i = 0; i < expected_weight_p_child_1.size(); i++) {
assertEquals(parents.get(10).getWeightsW().get(i), expected_weight_w_child_1.get(i), 0.00001);
assertEquals(parents.get(10).getWeightsP().get(i), expected_weight_p_child_1.get(i), 0.00001);
assertEquals(parents.get(11).getWeightsW().get(i), expected_weight_w_child_2.get(i), 0.00001);
assertEquals(parents.get(11).getWeightsP().get(i), expected_weight_p_child_2.get(i), 0.00001);
}
/*
assertTrue(parents.get(10).getWeightsW().equals(expected_weight_w_child_1));
assertTrue(parents.get(10).getWeightsP().equals(expected_weight_p_child_1));
assertTrue(parents.get(11).getWeightsW().equals(expected_weight_w_child_2));
assertTrue(parents.get(11).getWeightsP().equals(expected_weight_p_child_2));
*/
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment