diff --git a/pom.xml b/pom.xml index 218428fdf8d2b3ad41a4a484cc4d4ec629737a66..5bc24afab4e38d08519aaaa1452144c26d1a16d0 100644 --- a/pom.xml +++ b/pom.xml @@ -24,6 +24,13 @@ <version>5.3.1</version> <scope>test</scope> </dependency> + <!-- https://mvnrepository.com/artifact/info.debatty/java-aggregation --> + <dependency> + <groupId>info.debatty</groupId> + <artifactId>java-aggregation</artifactId> + <version>0.2</version> + </dependency> + </dependencies> <build> <plugins> diff --git a/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java b/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java index 34fc75e3fa6e3c3bc55ece0fd2c227576114c535..a42bf8f485534dd5c707bf4fc1f5efa14f77db2b 100644 --- a/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java +++ b/src/main/java/be/cylab/java/wowa/training/SolutionDistance.java @@ -1,16 +1,27 @@ package be.cylab.java.wowa.training; +import info.debatty.java.aggregation.WOWA; + public class SolutionDistance { private double[] weights_w; private double[] weights_p; private double distance = Double.POSITIVE_INFINITY; - public float computeScoreTo() { - return 0; + public void computeScoreTo(double[][] data, double[] expected) { + this.distance = 0; + for (int i = 0; i < data.length; i++) { + double[] vector = data[i]; + double target_value = expected[i]; + WOWA wowa = new WOWA(this.weights_w, this.weights_p); + double aggregated_value = wowa.aggregates(vector); + this.distance += Math.pow(target_value - aggregated_value, 2); + } + this.distance = Math.sqrt(this.distance); + } - public void randomlyMutateWithProbability(float probability) { + public void randomlyMutateWithProbability(double probability) { } diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java index f7fe21c1195ee05179626b5fd1036ceaeffe6223..982831de662b37729f5caf447b894a72d7c71f97 100644 --- a/src/main/java/be/cylab/java/wowa/training/Trainer.java +++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java @@ -8,7 +8,7 @@ public class Trainer { this.parameters = parameters; } - public SolutionDistance run(float[][] data, float[] expected) { + public SolutionDistance run(double[][] data, double[] expected) { } @@ -24,7 +24,7 @@ public class Trainer { return 0; } - public SolutionDistance[] computeDistances(SolutionDistance[] solutions, float[][] data, float[] expected) { + public SolutionDistance[] computeDistances(SolutionDistance[] solutions, double[][] data, double[] expected) { } @@ -56,7 +56,7 @@ public class Trainer { return this.parameters; } - public SolutionDistance[] generateInitialPopulationAndComputeDistances(int populationSize, float[][] data, float[] expected) { + public SolutionDistance[] generateInitialPopulationAndComputeDistances(int populationSize, double[][] data, double[] expected) { }