diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java
index 734fdf1834c75ed8e13ba2b98fcb3818f73a375c..01ff230160e5c5146e702b8f4c128f05b6b50000 100644
--- a/src/main/java/be/cylab/java/wowa/training/Trainer.java
+++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java
@@ -166,7 +166,7 @@ public class Trainer {
         int data_size = expected.size();
         for (int i = 0; i < data_size; i++) {
             if (expected.get(i) == 1) {
-                for (int j = 0; j < increase_ratio; j++) {
+                for (int j = 0; j < increase_ratio - 1; j++) {
                     expected.add(expected.get(i));
                     data.add(data.get(i));
                 }
@@ -215,6 +215,17 @@ public class Trainer {
             }
             fold_dataset.add(tmp);
         }
+        int fold_counter = 0;
+        while (dataset.getLength() > 0) {
+            int i = Utils.randomInteger(0, dataset.getLength() - 1);
+            fold_dataset.get(fold_counter).addElementInDataset(dataset, i);
+            dataset.removeElementInDataset(i);
+            if (fold_counter == fold_dataset.size() - 1) {
+                fold_counter = 0;
+            } else {
+                fold_counter++;
+            }
+        }
         return fold_dataset;
     }
 
diff --git a/src/main/java/be/cylab/java/wowa/training/Utils.java b/src/main/java/be/cylab/java/wowa/training/Utils.java
index 52d1c8c28ff38d9858f67d1bb0dbba01589891b2..362b56ec110d0910d96e3554caea837ae2975ee3 100644
--- a/src/main/java/be/cylab/java/wowa/training/Utils.java
+++ b/src/main/java/be/cylab/java/wowa/training/Utils.java
@@ -111,8 +111,10 @@ final class Utils {
      * @return
      */
     static int randomInteger(final int min, final int max) {
-        if (min >= max) {
+        if (min > max) {
             throw new IllegalArgumentException("Max must be greater then min");
+        } else if (min == max) {
+            return min;
         }
         Random rnd = new Random();
         return rnd.nextInt((max - min) + 1) + min;
diff --git a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
index bce9edb6c521cb674ac64511bef4deaae76209ff..fd5d815f5260415d4c216c6a49134e66f358bf64 100644
--- a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
+++ b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
@@ -27,29 +27,30 @@ class TrainerTest {
         List<List<Double>> data = generateData(100, 5);
         int original_data_set_size = data.size();
         List<Double> expected = generateExpectedBinaryClassification(100);
-        int increase_ration = 10;
-        double number_of_true_alert = Utils.sumListElements(expected);
-        TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ration);
+        int increase_ratio = 10;
+        int number_of_true_alert = (int)(double)Utils.sumListElements(expected);
+        int number_of_no_alert = original_data_set_size - number_of_true_alert;
+        TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio);
 
         //Check if the length of the dataset us correct
         //increase_ration * number_of_true_alert + expected.size()
-        assertEquals(630, ds.getLength());
+        assertEquals(number_of_true_alert * increase_ratio + number_of_no_alert, ds.getLength());
 
         //Check the number of true_alert in the dataset
         //increase_ration * number_of_true_alert + number_of_true_alert
-        assertEquals(583, (double)Utils.sumListElements(ds.getExpected()));
+        assertEquals(number_of_true_alert * increase_ratio, (double)Utils.sumListElements(ds.getExpected()));
 
         //Check if each rue alert elements are present (increase_ratio time) in he dataset
         //Here, we check if each true alert element is present 10 more times than the original one
         for (int i = 0; i < expected.size(); i++) {
             if (ds.getExpected().get(i) == 1.0) {
                 int cnt = 0;
-                for (int j = original_data_set_size; j < ds.getLength(); j++) {
+                for (int j = 0; j < ds.getLength(); j++) {
                     if (ds.getExpected().get(i) == ds.getExpected().get(j)) {
                         cnt++;
                     }
                 }
-                assertEquals(increase_ration, cnt);
+                assertEquals(increase_ratio, cnt);
             }
         }
 
@@ -57,13 +58,29 @@ class TrainerTest {
 
     @Test
     void testPrepareFolds() {
-        List<List<Double>> data = generateData(100,5);
-        List<Double> expected = generateExpectedBinaryClassification(100);
-        int increase_ratio = 5;
-        int fold_number = 5;
+        int number_of_elements = 100;
+        List<List<Double>> data = generateData(number_of_elements,5);
+        List<Double> expected = generateExpectedBinaryClassification(number_of_elements);
+        int increase_ratio = 3;
+        int fold_number = 10;
+        int number_of_alert = (int)(double)Utils.sumListElements(expected);
+        int number_of_no_alert = number_of_elements - number_of_alert;
         TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio);
         List<TrainingDataset> folds = trainer.prepareFolds(ds, fold_number);
         assertEquals(fold_number, folds.size());
+        for (int i = 0; i < folds.size(); i++) {
+            assertTrue(folds.get(i).getLength()
+                    == (number_of_alert * increase_ratio + number_of_no_alert)
+                    / fold_number || folds.get(i).getLength()
+                    == 1 + (number_of_alert * increase_ratio + number_of_no_alert)
+                    / fold_number);
+            assertTrue(Utils.sumListElements(folds.get(i).getExpected())
+                    == ((number_of_alert * increase_ratio) / fold_number) ||
+                    Utils.sumListElements(folds.get(i).getExpected())
+                            == 1 + ((number_of_alert * increase_ratio) / fold_number)
+                    || Utils.sumListElements(folds.get(i).getExpected())
+                    == 2 + ((number_of_alert * increase_ratio) / fold_number));
+        }
     }
 
     @Test
diff --git a/src/test/java/be/cylab/java/wowa/training/UtilsTest.java b/src/test/java/be/cylab/java/wowa/training/UtilsTest.java
index 0f41032d3946fb0900ab68a6450d678f4cb530ce..57b7920d97083f6fe536a7590894db2e2b60bb90 100644
--- a/src/test/java/be/cylab/java/wowa/training/UtilsTest.java
+++ b/src/test/java/be/cylab/java/wowa/training/UtilsTest.java
@@ -162,8 +162,8 @@ class UtilsTest {
         int a = Utils.randomInteger(0, 500);
         assertTrue(a > 0);
         assertTrue(a < 500);
+        assertEquals(45, Utils.randomInteger(45, 45));
         assertThrows(IllegalArgumentException.class, () -> {Utils.randomInteger(125, 120);});
-        assertThrows(IllegalArgumentException.class, () -> {Utils.randomInteger(125, 125);});
         int b = Utils.randomInteger(45,46);
         assertTrue(b == 45 || b == 46);