diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java index 734fdf1834c75ed8e13ba2b98fcb3818f73a375c..01ff230160e5c5146e702b8f4c128f05b6b50000 100644 --- a/src/main/java/be/cylab/java/wowa/training/Trainer.java +++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java @@ -166,7 +166,7 @@ public class Trainer { int data_size = expected.size(); for (int i = 0; i < data_size; i++) { if (expected.get(i) == 1) { - for (int j = 0; j < increase_ratio; j++) { + for (int j = 0; j < increase_ratio - 1; j++) { expected.add(expected.get(i)); data.add(data.get(i)); } @@ -215,6 +215,17 @@ public class Trainer { } fold_dataset.add(tmp); } + int fold_counter = 0; + while (dataset.getLength() > 0) { + int i = Utils.randomInteger(0, dataset.getLength() - 1); + fold_dataset.get(fold_counter).addElementInDataset(dataset, i); + dataset.removeElementInDataset(i); + if (fold_counter == fold_dataset.size() - 1) { + fold_counter = 0; + } else { + fold_counter++; + } + } return fold_dataset; } diff --git a/src/main/java/be/cylab/java/wowa/training/Utils.java b/src/main/java/be/cylab/java/wowa/training/Utils.java index 52d1c8c28ff38d9858f67d1bb0dbba01589891b2..362b56ec110d0910d96e3554caea837ae2975ee3 100644 --- a/src/main/java/be/cylab/java/wowa/training/Utils.java +++ b/src/main/java/be/cylab/java/wowa/training/Utils.java @@ -111,8 +111,10 @@ final class Utils { * @return */ static int randomInteger(final int min, final int max) { - if (min >= max) { + if (min > max) { throw new IllegalArgumentException("Max must be greater then min"); + } else if (min == max) { + return min; } Random rnd = new Random(); return rnd.nextInt((max - min) + 1) + min; diff --git a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java index bce9edb6c521cb674ac64511bef4deaae76209ff..fd5d815f5260415d4c216c6a49134e66f358bf64 100644 --- a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java +++ b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java @@ -27,29 +27,30 @@ class TrainerTest { List<List<Double>> data = generateData(100, 5); int original_data_set_size = data.size(); List<Double> expected = generateExpectedBinaryClassification(100); - int increase_ration = 10; - double number_of_true_alert = Utils.sumListElements(expected); - TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ration); + int increase_ratio = 10; + int number_of_true_alert = (int)(double)Utils.sumListElements(expected); + int number_of_no_alert = original_data_set_size - number_of_true_alert; + TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio); //Check if the length of the dataset us correct //increase_ration * number_of_true_alert + expected.size() - assertEquals(630, ds.getLength()); + assertEquals(number_of_true_alert * increase_ratio + number_of_no_alert, ds.getLength()); //Check the number of true_alert in the dataset //increase_ration * number_of_true_alert + number_of_true_alert - assertEquals(583, (double)Utils.sumListElements(ds.getExpected())); + assertEquals(number_of_true_alert * increase_ratio, (double)Utils.sumListElements(ds.getExpected())); //Check if each rue alert elements are present (increase_ratio time) in he dataset //Here, we check if each true alert element is present 10 more times than the original one for (int i = 0; i < expected.size(); i++) { if (ds.getExpected().get(i) == 1.0) { int cnt = 0; - for (int j = original_data_set_size; j < ds.getLength(); j++) { + for (int j = 0; j < ds.getLength(); j++) { if (ds.getExpected().get(i) == ds.getExpected().get(j)) { cnt++; } } - assertEquals(increase_ration, cnt); + assertEquals(increase_ratio, cnt); } } @@ -57,13 +58,29 @@ class TrainerTest { @Test void testPrepareFolds() { - List<List<Double>> data = generateData(100,5); - List<Double> expected = generateExpectedBinaryClassification(100); - int increase_ratio = 5; - int fold_number = 5; + int number_of_elements = 100; + List<List<Double>> data = generateData(number_of_elements,5); + List<Double> expected = generateExpectedBinaryClassification(number_of_elements); + int increase_ratio = 3; + int fold_number = 10; + int number_of_alert = (int)(double)Utils.sumListElements(expected); + int number_of_no_alert = number_of_elements - number_of_alert; TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio); List<TrainingDataset> folds = trainer.prepareFolds(ds, fold_number); assertEquals(fold_number, folds.size()); + for (int i = 0; i < folds.size(); i++) { + assertTrue(folds.get(i).getLength() + == (number_of_alert * increase_ratio + number_of_no_alert) + / fold_number || folds.get(i).getLength() + == 1 + (number_of_alert * increase_ratio + number_of_no_alert) + / fold_number); + assertTrue(Utils.sumListElements(folds.get(i).getExpected()) + == ((number_of_alert * increase_ratio) / fold_number) || + Utils.sumListElements(folds.get(i).getExpected()) + == 1 + ((number_of_alert * increase_ratio) / fold_number) + || Utils.sumListElements(folds.get(i).getExpected()) + == 2 + ((number_of_alert * increase_ratio) / fold_number)); + } } @Test diff --git a/src/test/java/be/cylab/java/wowa/training/UtilsTest.java b/src/test/java/be/cylab/java/wowa/training/UtilsTest.java index 0f41032d3946fb0900ab68a6450d678f4cb530ce..57b7920d97083f6fe536a7590894db2e2b60bb90 100644 --- a/src/test/java/be/cylab/java/wowa/training/UtilsTest.java +++ b/src/test/java/be/cylab/java/wowa/training/UtilsTest.java @@ -162,8 +162,8 @@ class UtilsTest { int a = Utils.randomInteger(0, 500); assertTrue(a > 0); assertTrue(a < 500); + assertEquals(45, Utils.randomInteger(45, 45)); assertThrows(IllegalArgumentException.class, () -> {Utils.randomInteger(125, 120);}); - assertThrows(IllegalArgumentException.class, () -> {Utils.randomInteger(125, 125);}); int b = Utils.randomInteger(45,46); assertTrue(b == 45 || b == 46);