diff --git a/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java b/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java index a32c2072d5aa357817fc5598b1ba3f0f7231f0c8..77271a6fd85d9054c3fe4dfbc9329a6e70164acc 100644 --- a/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java +++ b/src/main/java/be/cylab/java/wowa/training/AbstractSolution.java @@ -13,7 +13,7 @@ public abstract class AbstractSolution protected List<Double> weights_w; protected List<Double> weights_p; - protected double fitness_score = Double.POSITIVE_INFINITY; + protected Double fitness_score = Double.POSITIVE_INFINITY; /** * SolutionDistance constructor. Needs weight number as parameter @@ -49,7 +49,8 @@ public abstract class AbstractSolution this.normalize(); } - abstract void computeScoreTo(List<double[]> data, double[] expected); + abstract void computeScoreTo(List<List<Double>> data, + List<Double> expected); /** * @return diff --git a/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java b/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java index 6cadf111397ed79ea4e8c5fd948f1d9e2acc214d..5be9eac4a1177959c4c75800762ef08eae1b3d12 100644 --- a/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java +++ b/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java @@ -41,18 +41,19 @@ public class SolutionDistance extends AbstractSolution { * @param data * @param expected */ - final void computeScoreTo(final List<double[]> data, - final double[] expected + final void computeScoreTo(final List<List<Double>> data, + final List<Double> expected ) { this.fitness_score = 0; for (int i = 0; i < data.size(); i++) { - double[] vector = data.get(i); - double target_value = expected[i]; + List<Double> vector = data.get(i); + Double target_value = expected.get(i); WOWA wowa = new WOWA( Utils.convertListDoubleToArrayDouble(this.weights_w), Utils.convertListDoubleToArrayDouble(this.weights_p)); - double aggregated_value = wowa.aggregate(vector); + Double aggregated_value = wowa.aggregate( + Utils.convertListDoubleToArrayDouble(vector)); this.fitness_score += Math.pow(target_value - aggregated_value, 2); } this.fitness_score = Math.sqrt(this.fitness_score); diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java index 093b669d5752d23ce5b74f847b2cefc5b7fced23..cf3f6a8ef4239b0bf481b49ffd53c2848271b0da 100644 --- a/src/main/java/be/cylab/java/wowa/training/Trainer.java +++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java @@ -37,8 +37,8 @@ public class Trainer { public final AbstractSolution run( final String data_file_name, final String expected_file_name) { - List<double[]> data = Utils.convertJsonToDataForTrainer(data_file_name); - double[] expected + List<List<Double>> data = Utils.convertJsonToDataForTrainer(data_file_name); + List<Double> expected = Utils.convertJsonToExpectedForTrainer(expected_file_name); return this.run(data, expected); } @@ -49,8 +49,8 @@ public class Trainer { * @return */ public final AbstractSolution run( - final List<double[]> data, - final double[] expected) { + final List<List<Double>> data, + final List<Double> expected) { List<AbstractSolution> current_population = this.generateInitialPopulationAndComputeDistances( @@ -184,8 +184,8 @@ public class Trainer { */ final List<AbstractSolution> computeDistances( final List<AbstractSolution> solutions, - final List<double[]> data, - final double[] expected) { + final List<List<Double>> data, + final List<Double> expected) { for (AbstractSolution solution : solutions) { solution.computeScoreTo(data, expected); @@ -477,10 +477,10 @@ public class Trainer { */ private List<AbstractSolution> generateInitialPopulationAndComputeDistances( final int population_size, - final List<double[]> data, - final double[] expected) { + final List<List<Double>> data, + final List<Double> expected) { - int number_of_weights = data.get(0).length; + int number_of_weights = data.get(0).size(); List<AbstractSolution> initial_population = new ArrayList<>(); if (this.parameters.getPopulationInitializationMethod() == TrainerParameters.POPULATION_INITIALIZATION_RANDOM) { @@ -505,8 +505,8 @@ public class Trainer { */ final List<AbstractSolution> performReproduction( final List<AbstractSolution> population, - final List<double[]> data, - final double[] expected) { + final List<List<Double>> data, + final List<Double> expected) { List<AbstractSolution> parents = this.selectParents(population, this.getParameters().getNumberOfParents(), diff --git a/src/main/java/be/cylab/java/wowa/training/Utils.java b/src/main/java/be/cylab/java/wowa/training/Utils.java index 1879e6fbda3448b08a7995c8031bd7f2d00870b6..a1c47e936d6799bfd732dc50799903bdcfbfdfb3 100644 --- a/src/main/java/be/cylab/java/wowa/training/Utils.java +++ b/src/main/java/be/cylab/java/wowa/training/Utils.java @@ -126,13 +126,13 @@ final class Utils { * @param filename * @return */ - public static List<double[]> convertJsonToDataForTrainer( + public static List<List<Double>> convertJsonToDataForTrainer( final String filename) { Genson genson = new Genson(); String data_json = Utils.readFileToString(filename); - List<double[]> data = genson.deserialize( + List<List<Double>> data = genson.deserialize( data_json, - new GenericType<List<double[]>>() { + new GenericType<List<List<Double>>>() { }); return data; } @@ -143,13 +143,14 @@ final class Utils { * @param filename * @return */ - public static double[] convertJsonToExpectedForTrainer( + public static List<Double> convertJsonToExpectedForTrainer( final String filename) { Genson genson = new Genson(); String expected_json = Utils.readFileToString(filename); - - double[] expected = genson.deserialize(expected_json, double[].class); + List<Double> expected = genson.deserialize(expected_json, + new GenericType<List<Double>>() { + }); return expected; }