Skip to content
Snippets Groups Projects
Commit 8e142cf2 authored by a.croix's avatar a.croix
Browse files

Add new test for cross-validation

parent 74575555
No related branches found
No related tags found
No related merge requests found
Pipeline #2117 failed
...@@ -158,7 +158,7 @@ public class Trainer { ...@@ -158,7 +158,7 @@ public class Trainer {
* @param expected * @param expected
* @param increase_ratio * @param increase_ratio
*/ */
private TrainingDataset increaseTrueAlert( TrainingDataset increaseTrueAlert(
final List<List<Double>> data, final List<List<Double>> data,
final List<Double> expected, final List<Double> expected,
final int increase_ratio) { final int increase_ratio) {
...@@ -181,7 +181,7 @@ public class Trainer { ...@@ -181,7 +181,7 @@ public class Trainer {
* @param fold_number * @param fold_number
* @return * @return
*/ */
private List<TrainingDataset> prepareFolds( List<TrainingDataset> prepareFolds(
final TrainingDataset dataset, final TrainingDataset dataset,
final int fold_number) { final int fold_number) {
//List<List<Double>> data = dataset.getData(); //List<List<Double>> data = dataset.getData();
......
...@@ -3,10 +3,12 @@ package be.cylab.java.wowa.training; ...@@ -3,10 +3,12 @@ package be.cylab.java.wowa.training;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.util.*; import java.util.ArrayList;
import java.util.logging.Logger; import java.util.Arrays;
import java.util.List;
import java.util.Random;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.assertEquals;
class TrainerTest { class TrainerTest {
...@@ -22,6 +24,50 @@ class TrainerTest { ...@@ -22,6 +24,50 @@ class TrainerTest {
void testRun() { void testRun() {
} }
@Test
void testIncreaseTrueAlert() {
List<List<Double>> data = generateData(100, 5);
int original_data_set_size = data.size();
List<Double> expected = generateExpectedBinaryClassification(100);
int increase_ration = 10;
double number_of_true_alert = Utils.sumListElements(expected);
TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ration);
//Check if the length of the dataset us correct
//increase_ration * number_of_true_alert + expected.size()
assertEquals(630, ds.getLength());
//Check the number of true_alert in the dataset
//increase_ration * number_of_true_alert + number_of_true_alert
assertEquals(583, (double)Utils.sumListElements(ds.getExpected()));
//Check if each rue alert elements are present (increase_ratio time) in he dataset
//Here, we check if each true alert element is present 10 more times than the original one
for (int i = 0; i < expected.size(); i++) {
if (ds.getExpected().get(i) == 1.0) {
int cnt = 0;
for (int j = original_data_set_size; j < ds.getLength(); j++) {
if (ds.getExpected().get(i) == ds.getExpected().get(j)) {
cnt++;
}
}
assertEquals(increase_ration, cnt);
}
}
}
@Test
void testPrepareFolds() {
List<List<Double>> data = generateData(100,5);
List<Double> expected = generateExpectedBinaryClassification(100);
int increase_ratio = 5;
int fold_number = 5;
TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio);
List<TrainingDataset> folds = trainer.prepareFolds(ds, fold_number);
assertEquals(fold_number, folds.size());
}
@Test @Test
void testFindBestSolution() { void testFindBestSolution() {
List<AbstractSolution> population = new ArrayList<>(); List<AbstractSolution> population = new ArrayList<>();
...@@ -49,6 +95,39 @@ class TrainerTest { ...@@ -49,6 +95,39 @@ class TrainerTest {
@Test @Test
void testComputeDistances() { void testComputeDistances() {
List<AbstractSolution> population = new ArrayList<>();
Random rnd = new Random(5484);
double[] assertion = {1.5355936654577846,
1.5671165947226067,
1.5640401176636098,
1.6622610655715315,
1.5848944096279156,
1.4772261163161986,
1.5591155520984579,
1.6280763047959521,
1.640136934670875,
1.5399006101316708,
1.523255224403363,
1.7676299421451587,
1.5102792713787483,
1.4568545985467831,
1.6160572671527558,
1.4822660627936635,
1.7361461131034035,
1.4686111015215462,
1.551779317992214,
1.4794365689675926};
for (int i = 0; i < 20; i++) {
AbstractSolution solution = new SolutionDistance(5, rnd.nextInt());
population.add(solution);
}
List<List<Double>> data = generateData(20, 5);
List<Double> expected = generateExpected(20);
List<AbstractSolution> computed_population = trainer.computeDistances(population, data, expected);
for (int i = 0; i < computed_population.size(); i++) {
assertEquals(assertion[i], computed_population.get(i).getFitnessScore());
}
} }
...@@ -137,4 +216,17 @@ class TrainerTest { ...@@ -137,4 +216,17 @@ class TrainerTest {
} }
return expected; return expected;
} }
static List<Double> generateExpectedBinaryClassification(final int size) {
Random rnd = new Random(5768);
List<Double> expected = new ArrayList<>();
for (int i = 0; i < size; i++) {
if (rnd.nextDouble() <= 0.5) {
expected.add(new Double(0.0));
} else {
expected.add(new Double(1.0));
}
}
return expected;
}
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment