From 8e142cf2b24b2fd012b637756748689ed201f2e1 Mon Sep 17 00:00:00 2001
From: Alex <croix.alexandre@gmail.com>
Date: Tue, 13 Aug 2019 18:16:41 +0200
Subject: [PATCH] Add new test for cross-validation

Make Trainer.increaseTrueAlert and Trainer.prepareFolds package-private
so that TrainerTest can call them directly, and add tests for
increaseTrueAlert, prepareFolds and computeDistances.

---
 .../be/cylab/java/wowa/training/Trainer.java  |  4 +-
 .../cylab/java/wowa/training/TrainerTest.java | 98 ++++++++++++++++++-
 2 files changed, 97 insertions(+), 5 deletions(-)

diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java
index 91fc9d3..dd394b9 100644
--- a/src/main/java/be/cylab/java/wowa/training/Trainer.java
+++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java
@@ -158,7 +158,7 @@ public class Trainer {
      * @param expected
      * @param increase_ratio
      */
-    private TrainingDataset increaseTrueAlert(
+    TrainingDataset increaseTrueAlert(
             final List<List<Double>> data,
             final List<Double> expected,
             final int increase_ratio) {
@@ -181,7 +181,7 @@ public class Trainer {
      * @param fold_number
      * @return
      */
-    private List<TrainingDataset> prepareFolds(
+    List<TrainingDataset> prepareFolds(
             final TrainingDataset dataset,
             final int fold_number) {
         //List<List<Double>> data = dataset.getData();
diff --git a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
index 03781fc..23ace5e 100644
--- a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
+++ b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
@@ -3,10 +3,12 @@ package be.cylab.java.wowa.training;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
-import java.util.*;
-import java.util.logging.Logger;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
 
-import static org.junit.jupiter.api.Assertions.*;
+import static org.junit.jupiter.api.Assertions.assertEquals;
 
 class TrainerTest {
 
@@ -22,6 +24,50 @@ class TrainerTest {
     void testRun() {
     }
 
+    @Test
+    void testIncreaseTrueAlert() {
+        List<List<Double>> data = generateData(100, 5);
+        int original_data_set_size = data.size();
+        List<Double> expected = generateExpectedBinaryClassification(100);
+        int increase_ratio = 10;
+        double number_of_true_alert = Utils.sumListElements(expected);
+        TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio);
+
+        //Check if the length of the dataset is correct
+        //increase_ratio * number_of_true_alert + expected.size()
+        assertEquals(630, ds.getLength());
+
+        //Check the number of true alerts in the dataset
+        //increase_ratio * number_of_true_alert + number_of_true_alert
+        assertEquals(583, (double)Utils.sumListElements(ds.getExpected()));
+
+        //Check that each true alert row appears increase_ratio times in the appended part of the dataset.
+        //Rows are compared by content: every true alert label is 1.0, so labels cannot tell rows apart.
+        for (int i = 0; i < expected.size(); i++) {
+            if (ds.getExpected().get(i) == 1.0) {
+                int cnt = 0;
+                for (int j = original_data_set_size; j < ds.getLength(); j++) {
+                    if (ds.getData().get(i).equals(ds.getData().get(j))) {
+                        cnt++;
+                    }
+                }
+                assertEquals(increase_ratio, cnt);
+            }
+        }
+
+    }
+
+    @Test
+    void testPrepareFolds() {
+        List<List<Double>> data = generateData(100, 5);
+        List<Double> expected = generateExpectedBinaryClassification(100);
+        int increase_ratio = 5;
+        int fold_number = 5;
+        TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio);
+        List<TrainingDataset> folds = trainer.prepareFolds(ds, fold_number);
+        assertEquals(fold_number, folds.size());
+    }
+
     @Test
     void testFindBestSolution() {
         List<AbstractSolution> population = new ArrayList<>();
@@ -49,6 +95,39 @@ class TrainerTest {
 
     @Test
     void testComputeDistances() {
+        List<AbstractSolution> population = new ArrayList<>();
+        Random rnd = new Random(5484);
+        double[] assertion = {1.5355936654577846,
+                1.5671165947226067,
+                1.5640401176636098,
+                1.6622610655715315,
+                1.5848944096279156,
+                1.4772261163161986,
+                1.5591155520984579,
+                1.6280763047959521,
+                1.640136934670875,
+                1.5399006101316708,
+                1.523255224403363,
+                1.7676299421451587,
+                1.5102792713787483,
+                1.4568545985467831,
+                1.6160572671527558,
+                1.4822660627936635,
+                1.7361461131034035,
+                1.4686111015215462,
+                1.551779317992214,
+                1.4794365689675926};
+        for (int i = 0; i < 20; i++) {
+            AbstractSolution solution = new SolutionDistance(5, rnd.nextInt());
+            population.add(solution);
+        }
+        List<List<Double>> data = generateData(20, 5);
+        List<Double> expected = generateExpected(20);
+        List<AbstractSolution> computed_population = trainer.computeDistances(population, data, expected);
+        for (int i = 0; i < computed_population.size(); i++) {
+            assertEquals(assertion[i], computed_population.get(i).getFitnessScore(), 1e-12); // tolerance instead of exact double equality
+        }
+
     }
 
 
@@ -137,4 +216,17 @@ class TrainerTest {
         }
         return expected;
     }
+
+    static List<Double> generateExpectedBinaryClassification(final int size) {
+        Random rnd = new Random(5768);
+        List<Double> expected = new ArrayList<>();
+        for (int i = 0; i < size; i++) {
+            if (rnd.nextDouble() <= 0.5) {
+                expected.add(0.0);
+            } else {
+                expected.add(1.0);
+            }
+        }
+        return expected;
+    }
 }
\ No newline at end of file
-- 
GitLab