Skip to content
Snippets Groups Projects
Commit 8868a85e authored by a.croix's avatar a.croix
Browse files

Add last unit tests + improve prepareFolds method (Fix #5)

parent c57f43e1
No related branches found
No related tags found
No related merge requests found
Pipeline #2162 passed
......@@ -166,7 +166,7 @@ public class Trainer {
int data_size = expected.size();
for (int i = 0; i < data_size; i++) {
if (expected.get(i) == 1) {
for (int j = 0; j < increase_ratio; j++) {
for (int j = 0; j < increase_ratio - 1; j++) {
expected.add(expected.get(i));
data.add(data.get(i));
}
......@@ -215,6 +215,17 @@ public class Trainer {
}
fold_dataset.add(tmp);
}
int fold_counter = 0;
while (dataset.getLength() > 0) {
int i = Utils.randomInteger(0, dataset.getLength() - 1);
fold_dataset.get(fold_counter).addElementInDataset(dataset, i);
dataset.removeElementInDataset(i);
if (fold_counter == fold_dataset.size() - 1) {
fold_counter = 0;
} else {
fold_counter++;
}
}
return fold_dataset;
}
......
......@@ -111,8 +111,10 @@ final class Utils {
* @return
*/
static int randomInteger(final int min, final int max) {
if (min >= max) {
if (min > max) {
throw new IllegalArgumentException("Max must be greater then min");
} else if (min == max) {
return min;
}
Random rnd = new Random();
return rnd.nextInt((max - min) + 1) + min;
......
......@@ -27,29 +27,30 @@ class TrainerTest {
List<List<Double>> data = generateData(100, 5);
int original_data_set_size = data.size();
List<Double> expected = generateExpectedBinaryClassification(100);
int increase_ration = 10;
double number_of_true_alert = Utils.sumListElements(expected);
TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ration);
int increase_ratio = 10;
int number_of_true_alert = (int)(double)Utils.sumListElements(expected);
int number_of_no_alert = original_data_set_size - number_of_true_alert;
TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio);
//Check if the length of the dataset us correct
//increase_ration * number_of_true_alert + expected.size()
assertEquals(630, ds.getLength());
assertEquals(number_of_true_alert * increase_ratio + number_of_no_alert, ds.getLength());
//Check the number of true_alert in the dataset
//increase_ration * number_of_true_alert + number_of_true_alert
assertEquals(583, (double)Utils.sumListElements(ds.getExpected()));
assertEquals(number_of_true_alert * increase_ratio, (double)Utils.sumListElements(ds.getExpected()));
//Check if each rue alert elements are present (increase_ratio time) in he dataset
//Here, we check if each true alert element is present 10 more times than the original one
for (int i = 0; i < expected.size(); i++) {
if (ds.getExpected().get(i) == 1.0) {
int cnt = 0;
for (int j = original_data_set_size; j < ds.getLength(); j++) {
for (int j = 0; j < ds.getLength(); j++) {
if (ds.getExpected().get(i) == ds.getExpected().get(j)) {
cnt++;
}
}
assertEquals(increase_ration, cnt);
assertEquals(increase_ratio, cnt);
}
}
......@@ -57,13 +58,29 @@ class TrainerTest {
@Test
void testPrepareFolds() {
List<List<Double>> data = generateData(100,5);
List<Double> expected = generateExpectedBinaryClassification(100);
int increase_ratio = 5;
int fold_number = 5;
int number_of_elements = 100;
List<List<Double>> data = generateData(number_of_elements,5);
List<Double> expected = generateExpectedBinaryClassification(number_of_elements);
int increase_ratio = 3;
int fold_number = 10;
int number_of_alert = (int)(double)Utils.sumListElements(expected);
int number_of_no_alert = number_of_elements - number_of_alert;
TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio);
List<TrainingDataset> folds = trainer.prepareFolds(ds, fold_number);
assertEquals(fold_number, folds.size());
for (int i = 0; i < folds.size(); i++) {
assertTrue(folds.get(i).getLength()
== (number_of_alert * increase_ratio + number_of_no_alert)
/ fold_number || folds.get(i).getLength()
== 1 + (number_of_alert * increase_ratio + number_of_no_alert)
/ fold_number);
assertTrue(Utils.sumListElements(folds.get(i).getExpected())
== ((number_of_alert * increase_ratio) / fold_number) ||
Utils.sumListElements(folds.get(i).getExpected())
== 1 + ((number_of_alert * increase_ratio) / fold_number)
|| Utils.sumListElements(folds.get(i).getExpected())
== 2 + ((number_of_alert * increase_ratio) / fold_number));
}
}
@Test
......
......@@ -162,8 +162,8 @@ class UtilsTest {
int a = Utils.randomInteger(0, 500);
assertTrue(a > 0);
assertTrue(a < 500);
assertEquals(45, Utils.randomInteger(45, 45));
assertThrows(IllegalArgumentException.class, () -> {Utils.randomInteger(125, 120);});
assertThrows(IllegalArgumentException.class, () -> {Utils.randomInteger(125, 125);});
int b = Utils.randomInteger(45,46);
assertTrue(b == 45 || b == 46);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment