From 8e142cf2b24b2fd012b637756748689ed201f2e1 Mon Sep 17 00:00:00 2001
From: Alex <croix.alexandre@gmail.com>
Date: Tue, 13 Aug 2019 18:16:41 +0200
Subject: [PATCH] Add new test for cross-validation

---
 .../be/cylab/java/wowa/training/Trainer.java  |  4 +-
 .../cylab/java/wowa/training/TrainerTest.java | 98 ++++++++++++++++++-
 2 files changed, 97 insertions(+), 5 deletions(-)

diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java
index 91fc9d3..dd394b9 100644
--- a/src/main/java/be/cylab/java/wowa/training/Trainer.java
+++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java
@@ -158,7 +158,7 @@ public class Trainer {
      * @param expected
      * @param increase_ratio
      */
-    private TrainingDataset increaseTrueAlert(
+    TrainingDataset increaseTrueAlert(
             final List<List<Double>> data,
             final List<Double> expected,
             final int increase_ratio) {
@@ -181,7 +181,7 @@ public class Trainer {
      * @param fold_number
      * @return
      */
-    private List<TrainingDataset> prepareFolds(
+     List<TrainingDataset> prepareFolds(
             final TrainingDataset dataset,
             final int fold_number) {
         //List<List<Double>> data = dataset.getData();
diff --git a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
index 03781fc..23ace5e 100644
--- a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
+++ b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
@@ -3,10 +3,12 @@ package be.cylab.java.wowa.training;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
-import java.util.*;
-import java.util.logging.Logger;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
 
-import static org.junit.jupiter.api.Assertions.*;
+import static org.junit.jupiter.api.Assertions.assertEquals;
 
 class TrainerTest {
 
@@ -22,6 +24,50 @@ class TrainerTest {
     void testRun() {
     }
 
+    @Test
+    void testIncreaseTrueAlert() {
+        List<List<Double>> data = generateData(100, 5);
+        int original_data_set_size = data.size();
+        List<Double> expected = generateExpectedBinaryClassification(100);
+        int increase_ration = 10;
+        double number_of_true_alert = Utils.sumListElements(expected);
+        TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ration);
+
+        //Check if the length of the dataset us correct
+        //increase_ration * number_of_true_alert + expected.size()
+        assertEquals(630, ds.getLength());
+
+        //Check the number of true_alert in the dataset
+        //increase_ration * number_of_true_alert + number_of_true_alert
+        assertEquals(583, (double)Utils.sumListElements(ds.getExpected()));
+
+        //Check if each rue alert elements are present (increase_ratio time) in he dataset
+        //Here, we check if each true alert element is present 10 more times than the original one
+        for (int i = 0; i < expected.size(); i++) {
+            if (ds.getExpected().get(i) == 1.0) {
+                int cnt = 0;
+                for (int j = original_data_set_size; j < ds.getLength(); j++) {
+                    if (ds.getExpected().get(i) == ds.getExpected().get(j)) {
+                        cnt++;
+                    }
+                }
+                assertEquals(increase_ration, cnt);
+            }
+        }
+
+    }
+
+    @Test
+    void testPrepareFolds() {
+        List<List<Double>> data = generateData(100,5);
+        List<Double> expected = generateExpectedBinaryClassification(100);
+        int increase_ratio = 5;
+        int fold_number = 5;
+        TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio);
+        List<TrainingDataset> folds = trainer.prepareFolds(ds, fold_number);
+        assertEquals(fold_number, folds.size());
+    }
+
     @Test
     void testFindBestSolution() {
         List<AbstractSolution> population = new ArrayList<>();
@@ -49,6 +95,39 @@ class TrainerTest {
 
     @Test
     void testComputeDistances() {
+        List<AbstractSolution> population = new ArrayList<>();
+        Random rnd = new Random(5484);
+        double[] assertion = {1.5355936654577846,
+                1.5671165947226067,
+                1.5640401176636098,
+                1.6622610655715315,
+                1.5848944096279156,
+                1.4772261163161986,
+                1.5591155520984579,
+                1.6280763047959521,
+                1.640136934670875,
+                1.5399006101316708,
+                1.523255224403363,
+                1.7676299421451587,
+                1.5102792713787483,
+                1.4568545985467831,
+                1.6160572671527558,
+                1.4822660627936635,
+                1.7361461131034035,
+                1.4686111015215462,
+                1.551779317992214,
+                1.4794365689675926};
+        for (int i = 0; i < 20; i++) {
+            AbstractSolution solution = new SolutionDistance(5, rnd.nextInt());
+            population.add(solution);
+        }
+        List<List<Double>> data = generateData(20, 5);
+        List<Double> expected = generateExpected(20);
+        List<AbstractSolution> computed_population = trainer.computeDistances(population, data, expected);
+        for (int i = 0; i < computed_population.size(); i++) {
+            assertEquals(assertion[i], computed_population.get(i).getFitnessScore());
+        }
+
 
     }
 
@@ -137,4 +216,17 @@ class TrainerTest {
         }
         return expected;
     }
+
+    static List<Double> generateExpectedBinaryClassification(final int size) {
+        Random rnd = new Random(5768);
+        List<Double> expected = new ArrayList<>();
+        for (int i = 0; i < size; i++) {
+            if (rnd.nextDouble() <= 0.5) {
+                expected.add(new Double(0.0));
+            } else {
+                expected.add(new Double(1.0));
+            }
+        }
+        return  expected;
+    }
 }
\ No newline at end of file
-- 
GitLab