diff --git a/java-wowa-training.iml b/java-wowa-training.iml
index 08c09f61200f2f2388e08bc4ac406e511612c382..e65012bc30b00fe8f9ce599d0af7a8d7f858c371 100644
--- a/java-wowa-training.iml
+++ b/java-wowa-training.iml
@@ -42,13 +42,15 @@
     <orderEntry type="library" name="Maven: info.debatty:java-aggregation:0.5" level="project" />
     <orderEntry type="library" name="Maven: com.owlike:genson:1.6" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-lang3:3.8.1" level="project" />
-    <orderEntry type="library" name="Maven: be.cylab:java-roc:0.0.3" level="project" />
+    <orderEntry type="library" name="Maven: be.cylab:java-roc:0.0.4" level="project" />
     <orderEntry type="library" name="Maven: com.opencsv:opencsv:4.5" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-text:1.3" level="project" />
     <orderEntry type="library" name="Maven: commons-beanutils:commons-beanutils:1.9.3" level="project" />
     <orderEntry type="library" name="Maven: commons-logging:commons-logging:1.2" level="project" />
     <orderEntry type="library" name="Maven: commons-collections:commons-collections:3.2.2" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-collections4:4.2" level="project" />
+    <orderEntry type="library" name="Maven: org.knowm.xchart:xchart:3.5.4" level="project" />
+    <orderEntry type="library" name="Maven: de.erichseifert.vectorgraphics2d:VectorGraphics2D:0.13" level="project" />
     <orderEntry type="library" name="Maven: org.datavec:datavec-api:0.9.1" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-compress:1.8.1" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-math3:3.3" level="project" />
diff --git a/src/main/java/be/cylab/java/wowa/training/HyperParameters.java b/src/main/java/be/cylab/java/wowa/training/HyperParameters.java
new file mode 100644
index 0000000000000000000000000000000000000000..9915e6e3f926f7bb17b4cfdb4290aefbf5c0be8a
--- /dev/null
+++ b/src/main/java/be/cylab/java/wowa/training/HyperParameters.java
@@ -0,0 +1,140 @@
+package be.cylab.java.wowa.training;
+
+import org.deeplearning4j.nn.api.OptimizationAlgorithm;
+import org.nd4j.linalg.activations.Activation;
+
+/**
+ * Parameters of the NeuralNetwork class.
+ */
+public class HyperParameters {
+    private int neurons_number;
+    private double learning_rate;
+    private OptimizationAlgorithm algorithm;
+    private Activation activation_function;
+    private double percent_test_train;
+
+    /**
+     * Main constructor.
+     * @param neurons_number number of neurons in each hidden layer
+     * @param learning_rate learning rate of the optimizer
+     * @param algorithm optimization algorithm for backpropagation
+     * @param activation_function activation function of the hidden layers
+     * @param percent_test_train percentage of the data used for training
+     */
+    public HyperParameters(
+            final int neurons_number,
+            final double learning_rate,
+            final OptimizationAlgorithm algorithm,
+            final Activation activation_function,
+            final double percent_test_train) {
+        setNeuronsNumber(neurons_number);
+        setLearningRate(learning_rate);
+        setAlgorithm(algorithm);
+        setActivationFunction(activation_function);
+        setPercentTestTrain(percent_test_train);
+    }
+
+    /**
+     * Constructor without percent_test_train value (defaults to 99).
+     * @param neurons_number number of neurons in each hidden layer
+     * @param learning_rate learning rate of the optimizer
+     * @param algorithm optimization algorithm for backpropagation
+     * @param activation_function activation function of the hidden layers
+     */
+    public HyperParameters(
+            final int neurons_number,
+            final double learning_rate,
+            final OptimizationAlgorithm algorithm,
+            final Activation activation_function) {
+        this(neurons_number, learning_rate, algorithm,
+                activation_function, 99);
+    }
+
+    /**
+     * Getter for neurons_number.
+     * @return the number of neurons in each hidden layer
+     */
+    public final int getNeuronsNumber() {
+        return neurons_number;
+    }
+
+    /**
+     * Getter for learning rate.
+     * @return the learning rate
+     */
+    public final double getLearningRate() {
+        return learning_rate;
+    }
+
+    /**
+     * Getter for backpropagation algorithm.
+     * @return the optimization algorithm
+     */
+    public final OptimizationAlgorithm getAlgorithm() {
+        return algorithm;
+    }
+
+    /**
+     * Getter for activation function.
+     * @return the activation function
+     */
+    public final Activation getActivationFunction() {
+        return activation_function;
+    }
+
+    /**
+     * Getter for percent_test_train.
+     * @return the percentage of the data used for training
+     */
+    public final double getPercentTestTrain() {
+        return percent_test_train;
+    }
+
+    /**
+     * @param neurons_number number of neurons, between 5 and 80
+     */
+    public final void setNeuronsNumber(final int neurons_number) {
+        if (neurons_number < 5 || neurons_number > 80) {
+            throw new IllegalArgumentException(
+                    "Neuron number must be between 5 and 80");
+        }
+        this.neurons_number = neurons_number;
+    }
+
+    /**
+     * @param learning_rate learning rate, strictly between 0 and 1
+     */
+    public final void setLearningRate(final double learning_rate) {
+        if (learning_rate <= 0.0 || learning_rate >= 1.0) {
+            throw new IllegalArgumentException(
+                    "Learning rate must be between 0 and 1");
+        }
+        this.learning_rate = learning_rate;
+    }
+
+    /**
+     * @param algorithm optimization algorithm for backpropagation
+     */
+    public final void setAlgorithm(final OptimizationAlgorithm algorithm) {
+        this.algorithm = algorithm;
+    }
+
+    /**
+     * @param activation_function activation function of the hidden layers
+     */
+    public final void setActivationFunction(
+            final Activation activation_function) {
+        this.activation_function = activation_function;
+    }
+
+    /**
+     * @param percent_test_train training percentage, strictly between 0 and 100
+     */
+    public final void setPercentTestTrain(final double percent_test_train) {
+        if (percent_test_train <= 0 || percent_test_train >= 100) {
+            throw new IllegalArgumentException(
+                    "Percentage of train must be between 0 and 100");
+        }
+        this.percent_test_train = percent_test_train;
+    }
+}
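
For reference, a minimal usage sketch of the new HyperParameters class (the demo class name is hypothetical; everything else comes from the diff above): the four-argument constructor falls back to a 99% training split, and the setters reject out-of-range values.

    package be.cylab.java.wowa.training;

    import org.deeplearning4j.nn.api.OptimizationAlgorithm;
    import org.nd4j.linalg.activations.Activation;

    public class HyperParametersDemo {
        public static void main(String[] args) {
            // Valid configuration: 20 hidden neurons, learning rate 0.01.
            HyperParameters params = new HyperParameters(
                    20, 0.01,
                    OptimizationAlgorithm.CONJUGATE_GRADIENT,
                    Activation.TANH);
            System.out.println(params.getPercentTestTrain()); // 99.0 (default)

            // Out-of-range values are rejected by the setters.
            try {
                params.setNeuronsNumber(100); // must be between 5 and 80
            } catch (IllegalArgumentException e) {
                System.out.println(e.getMessage());
            }
        }
    }
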
diff --git a/src/main/java/be/cylab/java/wowa/training/LearningNeuralNetwork.java b/src/main/java/be/cylab/java/wowa/training/LearningNeuralNetwork.java
deleted file mode 100644
index 7328913d5b67cd2845395633e1ffff9b547b04eb..0000000000000000000000000000000000000000
--- a/src/main/java/be/cylab/java/wowa/training/LearningNeuralNetwork.java
+++ /dev/null
@@ -1,130 +0,0 @@
-package be.cylab.java.wowa.training;
-
-import org.datavec.api.records.reader.RecordReader;
-import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
-import org.datavec.api.split.FileSplit;
-import org.datavec.api.util.ClassPathResource;
-import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
-import org.deeplearning4j.eval.Evaluation;
-import org.deeplearning4j.eval.ROC;
-import org.deeplearning4j.nn.api.OptimizationAlgorithm;
-import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.layers.DenseLayer;
-import org.deeplearning4j.nn.conf.layers.OutputLayer;
-import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
-import org.deeplearning4j.nn.weights.WeightInit;
-import org.nd4j.linalg.activations.Activation;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.dataset.DataSet;
-import org.nd4j.linalg.dataset.SplitTestAndTrain;
-import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
-import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
-import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
-import org.nd4j.linalg.lossfunctions.LossFunctions;
-
-import java.io.IOException;
-
-/**
- * Class to test neural network learning.
- */
-
-public final class LearningNeuralNetwork {
-
-    /**
-     * Class count.
-     */
-    public static final int CLASSES_COUNT = 2;
-    /**
-     * Features count.
-     */
-    public static final int FEATURES_COUNT = 5;
-
-
-    private DataSet training_data;
-    private DataSet testing_data;
-    private double learning_rate;
-    private OptimizationAlgorithm algorithm;
-    private Activation activation_function;
-
-
-    /**
-     * Constructor.
-     *
-     * @param file_name
-     */
-    public LearningNeuralNetwork(final String file_name,
-                                 final double learning_rate,
-                                 final OptimizationAlgorithm algorithm,
-                                 final Activation activation_function
-    ) {
-        try (RecordReader record_reader = new CSVRecordReader(0, ',')) {
-            record_reader.initialize(new FileSplit(
-                    new ClassPathResource(file_name).getFile()
-            ));
-            DataSetIterator iterator = new RecordReaderDataSetIterator(
-                    record_reader, 12468, FEATURES_COUNT, CLASSES_COUNT);
-            DataSet all_data = iterator.next();
-            all_data.shuffle();
-            DataNormalization normalizer = new NormalizerStandardize();
-            normalizer.fit(all_data);
-            normalizer.transform(all_data);
-
-            SplitTestAndTrain test_and_train = all_data.splitTestAndTrain(0.90);
-            this.training_data = test_and_train.getTrain();
-            this.testing_data = test_and_train.getTest();
-
-        } catch (IOException e) {
-            e.printStackTrace();
-        } catch (InterruptedException e) {
-            e.printStackTrace();
-        }
-        this.learning_rate = learning_rate;
-        this.algorithm = algorithm;
-        this.activation_function = activation_function;
-    }
-
-    /**
-     * Method to build a simple neural network, train and test it.
-     *
-     * @param neurons_number
-     */
-    public void learning(final int neurons_number) {
-
-        MultiLayerConfiguration configuration
-                = new NeuralNetConfiguration.Builder()
-                .iterations(1000)
-                .activation(this.activation_function)
-                .optimizationAlgo(
-                        this.algorithm)
-                .weightInit(WeightInit.XAVIER)
-                .learningRate(this.learning_rate)
-                .regularization(true).l2(0.0001)
-                .list()
-                .layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT)
-                        .nOut(neurons_number).build())
-                .layer(1, new DenseLayer.Builder().nIn(neurons_number)
-                        .nOut(neurons_number).build())
-                .layer(2, new OutputLayer.Builder(LossFunctions
-                        .LossFunction.NEGATIVELOGLIKELIHOOD)
-                        .activation(Activation.SOFTMAX)
-                        .nIn(neurons_number)
-                        .nOut(CLASSES_COUNT).build())
-                .backprop(true).pretrain(false)
-                .build();
-
-        MultiLayerNetwork model = new MultiLayerNetwork(configuration);
-        model.init();
-        model.fit(this.training_data);
-
-        INDArray output
-                = model.output((this.testing_data.getFeatureMatrix()));
-
-        Evaluation eval = new Evaluation(CLASSES_COUNT);
-        eval.eval(this.testing_data.getLabels(), output);
-        System.out.println(eval.stats());
-        ROC roc = new ROC(CLASSES_COUNT);
-        roc.eval(this.testing_data.getLabels(), output);
-        System.out.println(roc.stats());
-    }
-}
diff --git a/src/main/java/be/cylab/java/wowa/training/MainDL4J.java b/src/main/java/be/cylab/java/wowa/training/MainDL4J.java
index d8a5b18a6ac2110b44ccec1bc938a8e6f14e6400..f9f0fa0bcc4b71699b15e8b73d87f1cfcd4e313a 100644
--- a/src/main/java/be/cylab/java/wowa/training/MainDL4J.java
+++ b/src/main/java/be/cylab/java/wowa/training/MainDL4J.java
@@ -1,8 +1,15 @@
 package be.cylab.java.wowa.training;
 
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.nd4j.linalg.activations.Activation;
 
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+
 /**
  * Class for learn in neuronal network.
  */
@@ -20,20 +27,24 @@ public final class MainDL4J {
      * @param args
      */
     public static void main(final String[] args) {
-
-        String file_name = args[0];
-        double learning_rate = Double.parseDouble(args[1]);
+        int begin_neuron_number = Integer.parseInt(args[0]);
+        int neuron_number_end = Integer.parseInt(args[1]);
+        String data_file = args[5];
+        String expected_file = args[6];
+        double learning_rate = Double.parseDouble(args[2]);
         OptimizationAlgorithm optimization_algorithm;
         Activation activation_function;
+        int fold_number = 10;
+        int increase_ratio = 10;
 
-        if (args[2].matches("CONJUGATE_GRADIENT")) {
+        if (args[3].matches("CONJUGATE_GRADIENT")) {
             optimization_algorithm = OptimizationAlgorithm.CONJUGATE_GRADIENT;
-        } else if (args[2].matches("LBFGS")) {
+        } else if (args[3].matches("LBFGS")) {
             optimization_algorithm = OptimizationAlgorithm.LBFGS;
-        } else if (args[2].matches("LINE_GRADIENT_DESCENT")) {
+        } else if (args[3].matches("LINE_GRADIENT_DESCENT")) {
             optimization_algorithm
                     = OptimizationAlgorithm.LINE_GRADIENT_DESCENT;
-        } else if (args[2].matches("STOCHASTIC_GRADIENT_DESCENT")) {
+        } else if (args[3].matches("STOCHASTIC_GRADIENT_DESCENT")) {
             optimization_algorithm
                     = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
         } else {
@@ -41,27 +52,48 @@ public final class MainDL4J {
                     "Not correct optimization algorithm");
         }
 
-        if (args[3].matches("RELU")) {
+        if (args[4].matches("RELU")) {
             activation_function = Activation.RELU;
-        } else if  (args[3].matches("SIGMOID")) {
+        } else if (args[4].matches("SIGMOID")) {
             activation_function = Activation.SIGMOID;
-        } else if (args[3].matches("TANH")) {
+        } else if (args[4].matches("TANH")) {
             activation_function = Activation.TANH;
         } else {
             throw new IllegalArgumentException(
                     "Not correct activation function");
         }
 
-        LearningNeuralNetwork lnn
-                = new LearningNeuralNetwork(file_name,
-                learning_rate,
-                optimization_algorithm,
-                activation_function);
-        for (int neurons_number = 5; neurons_number < 50; neurons_number++) {
-            System.out.println("Number of neurons in hidden layer : "
-                    + neurons_number);
-            lnn.learning(neurons_number);
+        for (int neuron_number = begin_neuron_number;
+             neuron_number < neuron_number_end; neuron_number++) {
+            double average_auc = 0;
+            System.out.println("Neuron number : " + neuron_number);
+            HyperParameters parameters
+                    = new HyperParameters(neuron_number, learning_rate,
+                    optimization_algorithm, activation_function);
+            NeuralNetwork nn = new NeuralNetwork(parameters);
+
+            HashMap<MultiLayerNetwork, Double> map = nn.runKFold(
+                    data_file,
+                    expected_file,
+                    fold_number,
+                    increase_ratio);
+            for (Double d : map.values()) {
+                average_auc = average_auc + d;
+            }
+            try (OutputStreamWriter writer = new OutputStreamWriter(
+                    new FileOutputStream("Synthesis_average_AUC.txt", true),
+                    StandardCharsets.UTF_8)) {
+                writer.write("Neuron number : " + neuron_number
+                        + " Learning rate : " + learning_rate
+                        + " Algorithm : " + args[3]
+                        + " Activation function : " + args[4]
+                        + " Average AUC = "
+                        + (average_auc / fold_number) + "\n");
+            } catch (IOException e) {
+                e.printStackTrace();
+            }
         }
 
+
     }
 }
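
MainDL4J now takes seven positional arguments instead of four, sweeps the hidden-layer size over [begin, end), and appends the average AUC of each sweep step to Synthesis_average_AUC.txt. A hedged sketch of an equivalent programmatic invocation (the demo class and file paths are illustrative):

    package be.cylab.java.wowa.training;

    public class MainDL4JDemo {
        public static void main(String[] args) {
            // Argument order expected by MainDL4J.main:
            // [0] first neuron count  [1] last neuron count (exclusive)
            // [2] learning rate       [3] optimization algorithm
            // [4] activation function [5] data file  [6] expected file
            MainDL4J.main(new String[] {
                    "5", "50", "0.01",
                    "CONJUGATE_GRADIENT", "TANH",
                    "ressources/webshell_data.json",
                    "ressources/webshell_expected.json"
            });
        }
    }
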
diff --git a/src/main/java/be/cylab/java/wowa/training/MainTest.java b/src/main/java/be/cylab/java/wowa/training/MainTest.java
new file mode 100644
index 0000000000000000000000000000000000000000..8a2f482f8d8191f2204adb06d67c54d7050597b5
--- /dev/null
+++ b/src/main/java/be/cylab/java/wowa/training/MainTest.java
@@ -0,0 +1,91 @@
+package be.cylab.java.wowa.training;
+
+import org.deeplearning4j.nn.api.OptimizationAlgorithm;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.nd4j.linalg.activations.Activation;
+
+import java.util.HashMap;
+import java.util.List;
+
+/**
+ * Main class to test and compare the efficiency of the neural network (nn) and the WOWA trainer (wt).
+ */
+public final class MainTest {
+
+    private MainTest() {
+
+    }
+
+    /**
+     * @param args command line arguments (not used)
+     */
+    public static void main(final String[] args) {
+        int population_size = 100;
+        int crossover_rate = 60;
+        int mutation_rate = 15;
+        int max_generation_number = 120;
+        int selection_method
+                = TrainerParameters.SELECTION_METHOD_RWS;
+        int generation_population_method
+                = TrainerParameters.POPULATION_INITIALIZATION_RANDOM;
+        String data_file = "ressources/webshell_data.json";
+        String expected_file = "ressources/webshell_expected.json";
+        List<List<Double>> data
+                = Utils.convertJsonToDataForTrainer(data_file);
+        List<Double> expected
+                = Utils.convertJsonToExpectedForTrainer(expected_file);
+
+        TrainerParameters wt_parameters = new TrainerParameters(
+                null,
+                population_size,
+                crossover_rate,
+                mutation_rate,
+                max_generation_number,
+                selection_method,
+                generation_population_method);
+        Trainer trainer = new Trainer(wt_parameters,
+                new SolutionDistance(5));
+
+        int neurons_number = 20;
+        double learning_rate = 0.01;
+        OptimizationAlgorithm algo = OptimizationAlgorithm.CONJUGATE_GRADIENT;
+        Activation activation_function = Activation.TANH;
+        HyperParameters nn_parameters = new HyperParameters(
+                neurons_number,
+                learning_rate,
+                algo,
+                activation_function
+        );
+        NeuralNetwork nn = new NeuralNetwork(nn_parameters);
+
+        TrainingDataset dataset = new TrainingDataset(data, expected);
+        List<TrainingDataset> folds = dataset.prepareFolds(10);
+
+        long start_time = System.currentTimeMillis();
+        System.out.println("Wowa training");
+        HashMap<AbstractSolution, Double> map_wt = trainer.runKFold(folds, 10);
+        long end_time = System.currentTimeMillis();
+        System.out.println("Execution time : " + (end_time - start_time) / 1000 + " seconds");
+
+        start_time = System.currentTimeMillis();
+        System.out.println("Neural Network learning");
+        HashMap<MultiLayerNetwork, Double> map_nn = nn.runKFold(folds, 10);
+        end_time = System.currentTimeMillis();
+        System.out.println("Execution time : " + (end_time - start_time) / 1000 + " seconds");
+
+        double nn_score = 0.0;
+        double wt_score = 0.0;
+        for (Double d : map_nn.values()) {
+            nn_score = nn_score + d;
+        }
+
+        System.out.println("Average AUC for Neural Network learning : "
+                + nn_score / 10);
+
+        for (Double d : map_wt.values()) {
+            wt_score = wt_score + d;
+        }
+
+        System.out.println("Average AUC for WOWA learning : " + wt_score / 10);
+    }
+}
diff --git a/src/main/java/be/cylab/java/wowa/training/NeuralNetwork.java b/src/main/java/be/cylab/java/wowa/training/NeuralNetwork.java
new file mode 100644
index 0000000000000000000000000000000000000000..fc3d5135930e9a771aa157527d74deb356283ee2
--- /dev/null
+++ b/src/main/java/be/cylab/java/wowa/training/NeuralNetwork.java
@@ -0,0 +1,278 @@
+package be.cylab.java.wowa.training;
+
+import org.datavec.api.records.reader.RecordReader;
+import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
+import org.datavec.api.split.FileSplit;
+import org.datavec.api.util.ClassPathResource;
+import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
+import org.deeplearning4j.eval.Evaluation;
+import org.deeplearning4j.eval.ROC;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.dataset.SplitTestAndTrain;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
+import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+
+/**
+ * Class to test neural network learning.
+ */
+
+public final class NeuralNetwork {
+
+    /**
+     * Class count.
+     */
+    public static final int CLASSES_COUNT = 2;
+    /**
+     * Features count.
+     */
+    public static final int FEATURES_COUNT = 5;
+
+
+    private HyperParameters parameters;
+
+
+    /**
+     * Main constructor.
+     *
+     * @param parameters hyperparameters of the network
+     */
+    public NeuralNetwork(
+            final HyperParameters parameters) {
+
+        this.parameters = parameters;
+    }
+
+    /**
+     * @param learning dataset used to train the model
+     * @return the trained network
+     */
+    MultiLayerNetwork learning(
+            final DataSet learning) {
+
+        int neurons_number = parameters.getNeuronsNumber();
+        MultiLayerConfiguration configuration
+                = new NeuralNetConfiguration.Builder()
+                .iterations(1000)
+                .activation(parameters.getActivationFunction())
+                .optimizationAlgo(
+                        parameters.getAlgorithm())
+                .weightInit(WeightInit.XAVIER)
+                .learningRate(parameters.getLearningRate())
+                .regularization(true).l2(0.0001)
+                .list()
+                .layer(0, new DenseLayer.Builder()
+                        .nIn(learning.getFeatures().columns())
+                        .nOut(neurons_number).build())
+                .layer(1, new DenseLayer.Builder().nIn(neurons_number)
+                        .nOut(neurons_number).build())
+                .layer(2, new OutputLayer.Builder(LossFunctions
+                        .LossFunction.NEGATIVELOGLIKELIHOOD)
+                        .activation(Activation.SOFTMAX)
+                        .nIn(neurons_number)
+                        .nOut(learning.getLabels().columns()).build())
+                .backprop(true).pretrain(false)
+                .build();
+
+        MultiLayerNetwork model = new MultiLayerNetwork(configuration);
+        model.init();
+        model.fit(learning);
+
+        return model;
+    }
+
+    /**
+     * @param data list of feature vectors
+     * @param expected expected class (1.0 for alert) for each vector
+     * @return the trained network
+     */
+    public MultiLayerNetwork run(
+            final List<List<Double>> data,
+            final List<Double> expected) {
+
+        DataSet all_data = Utils.createDataSet(data, expected);
+        // The full dataset is used for training here; splitting
+        // into train/test sets is left to the caller.
+        return this.learning(all_data);
+
+    }
+
+    /**
+     * @param data_filename JSON file containing the feature vectors
+     * @param expected_filename JSON file containing the expected classes
+     * @return the trained network
+     */
+    public MultiLayerNetwork run(
+            final String data_filename,
+            final String expected_filename) {
+        List<List<Double>> data
+                = Utils.convertJsonToDataForTrainer(data_filename);
+        List<Double> expected
+                = Utils.convertJsonToExpectedForTrainer(expected_filename);
+        return this.run(data, expected);
+    }
+
+    /**
+     * @param filename CSV file containing features and labels
+     * @return the trained network, or null if reading the file failed
+     */
+    public MultiLayerNetwork runCSV(final String filename) {
+        try (RecordReader record_reader = new CSVRecordReader(0, ',')) {
+            record_reader.initialize(new FileSplit(
+                    new ClassPathResource(filename).getFile()
+            ));
+            DataSetIterator iterator = new RecordReaderDataSetIterator(
+                    record_reader, 12468, FEATURES_COUNT, CLASSES_COUNT);
+            DataSet all_data = iterator.next();
+            // The full dataset is used for training here; splitting
+            // into train/test sets is left to the caller.
+            return learning(all_data);
+
+        } catch (IOException e) {
+            e.printStackTrace();
+        } catch (InterruptedException e) {
+            e.printStackTrace();
+        }
+        return null;
+    }
+
+    /**
+     * Method to evaluate the performance of the model.
+     * @param testing dataset used for evaluation
+     * @param network trained model to evaluate
+     * @return the area under the ROC curve (AUC)
+     */
+    public Double modelEvaluation(
+            final DataSet testing,
+            final MultiLayerNetwork network) {
+        INDArray output
+                = network.output((testing.getFeatureMatrix()));
+
+        Evaluation eval = new Evaluation(testing.getLabels().columns());
+        eval.eval(testing.getLabels(), output);
+        System.out.println(eval.stats());
+        ROC roc = new ROC(testing.getLabels().columns());
+        roc.eval(testing.getLabels(), output);
+        System.out.println(roc.stats());
+        return roc.calculateAUC();
+    }
+
+    /**
+     * @param testing dataset used for evaluation
+     * @param network trained model to evaluate
+     * @return the area under the ROC curve (AUC)
+     */
+    public Double modelEvaluation(
+            final TrainingDataset testing,
+            final MultiLayerNetwork network) {
+        DataSet test
+                = Utils.createDataSet(testing.getData(), testing.getExpected());
+        return modelEvaluation(test, network);
+    }
+
+    /**
+     * @param data dataset to shuffle, normalize and split
+     * @return the train/test split
+     */
+    SplitTestAndTrain prepareDataSetForTrainingAndTesting(final DataSet data) {
+        data.shuffle();
+        DataNormalization normalizer = new NormalizerStandardize();
+        normalizer.fit(data);
+        normalizer.transform(data);
+        // splitTestAndTrain expects a fraction, not a percentage.
+        return data.splitTestAndTrain(parameters.getPercentTestTrain() / 100);
+    }
+
+    /**
+     * @param data list of feature vectors
+     * @param expected expected class for each vector
+     * @param fold_number number of folds for cross-validation
+     * @param increase_ratio number of times each true alert is duplicated
+     * @return map of each trained network to its AUC on the held-out fold
+     */
+    public HashMap<MultiLayerNetwork, Double> runKFold(
+            final List<List<Double>> data,
+            final List<Double> expected,
+            final int fold_number,
+            final int increase_ratio) {
+        TrainingDataset dataset = new TrainingDataset(data, expected);
+        List<TrainingDataset> folds = dataset.prepareFolds(fold_number);
+        HashMap<MultiLayerNetwork, Double> map = new HashMap<>();
+        for (int i = 0; i < fold_number; i++) {
+            TrainingDataset testingfold = folds.get(i);
+            TrainingDataset learning_fold = new TrainingDataset();
+            for (int j = 0; j < fold_number; j++) {
+                if (j != i) {
+                    learning_fold.addFoldInDataset(folds, j);
+                }
+            }
+            TrainingDataset dataset_increased = learning_fold.increaseTrueAlert(
+                    increase_ratio);
+            MultiLayerNetwork nn = run(
+                    dataset_increased.getData(),
+                    dataset_increased.getExpected());
+            Double score = modelEvaluation(testingfold, nn);
+
+            map.put(nn, score);
+        }
+        return map;
+    }
+
+    /**
+     * @param filename_data JSON file containing the feature vectors
+     * @param filename_expected JSON file containing the expected classes
+     * @param fold_number number of folds for cross-validation
+     * @param increase_ratio number of times each true alert is duplicated
+     * @return map of each trained network to its AUC on the held-out fold
+     */
+    public HashMap<MultiLayerNetwork, Double> runKFold(
+            final String filename_data,
+            final String filename_expected,
+            final int fold_number,
+            final int increase_ratio) {
+        List<List<Double>> data
+                = Utils.convertJsonToDataForTrainer(filename_data);
+        List<Double> expected
+                = Utils.convertJsonToExpectedForTrainer(filename_expected);
+
+        return runKFold(data, expected, fold_number, increase_ratio);
+    }
+
+    HashMap<MultiLayerNetwork, Double> runKFold(
+            final List<TrainingDataset> prepared_folds,
+            final int increase_ratio) {
+        HashMap<MultiLayerNetwork, Double> map = new HashMap<>();
+        for (int i = 0; i < prepared_folds.size(); i++) {
+            TrainingDataset testingfold = prepared_folds.get(i);
+            TrainingDataset learning_fold = new TrainingDataset();
+            for (int j = 0; j < prepared_folds.size(); j++) {
+                if (j != i) {
+                    learning_fold.addFoldInDataset(prepared_folds, j);
+                }
+            }
+            TrainingDataset dataset_increased = learning_fold.increaseTrueAlert(
+                    increase_ratio);
+            System.out.println("Fold number : " + (i + 1));
+            MultiLayerNetwork nn = run(
+                    dataset_increased.getData(),
+                    dataset_increased.getExpected());
+            Double score = modelEvaluation(testingfold, nn);
+            map.put(nn, score);
+        }
+        return map;
+    }
+
+}
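
NeuralNetwork.runKFold returns a map from each per-fold trained network to its AUC on the held-out fold, so the cross-validated score is the mean of the map values, as MainDL4J computes above. A minimal sketch (the demo class name and paths are illustrative):

    package be.cylab.java.wowa.training;

    import org.deeplearning4j.nn.api.OptimizationAlgorithm;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.activations.Activation;

    import java.util.HashMap;

    public class NeuralNetworkKFoldDemo {
        public static void main(String[] args) {
            HyperParameters parameters = new HyperParameters(
                    20, 0.01,
                    OptimizationAlgorithm.CONJUGATE_GRADIENT,
                    Activation.TANH);
            NeuralNetwork nn = new NeuralNetwork(parameters);

            // 10-fold cross-validation, true alerts duplicated 10 times.
            HashMap<MultiLayerNetwork, Double> scores = nn.runKFold(
                    "ressources/webshell_data.json",
                    "ressources/webshell_expected.json",
                    10, 10);

            double sum = 0;
            for (Double auc : scores.values()) {
                sum += auc;
            }
            System.out.println("Average AUC : " + sum / scores.size());
        }
    }
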
diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java
index f42121513011caf37d326314aa884fe5de399301..fda997a226203456ae52646316344c90eed31f35 100644
--- a/src/main/java/be/cylab/java/wowa/training/Trainer.java
+++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java
@@ -104,7 +104,7 @@ public class Trainer {
             final int fold_number,
             final int increase_ration_alert) {
         TrainingDataset dataset = new TrainingDataset(data, expected);
-        List<TrainingDataset> folds = prepareFolds(dataset, fold_number);
+        List<TrainingDataset> folds = dataset.prepareFolds(fold_number);
         HashMap<AbstractSolution, Double> map = new HashMap<>();
         for (int i = 0; i < fold_number; i++) {
             TrainingDataset testing = folds.get(i);
@@ -114,9 +114,7 @@ public class Trainer {
                     learning.addFoldInDataset(folds, j);
                 }
             }
-            TrainingDataset dataset_increased = increaseTrueAlert(
-                    learning.getData(),
-                    learning.getExpected(),
+            TrainingDataset dataset_increased = learning.increaseTrueAlert(
                     increase_ration_alert);
             AbstractSolution sol = run(
                     dataset_increased.getData(),
@@ -160,83 +158,38 @@ public class Trainer {
     }
 
     /**
-     * Method to increase the number of true alert in data.
-     * to increase penalty to do not detect a true alert
-     *
-     * @param data
-     * @param expected
+     * @param prepared_folds folds produced by TrainingDataset.prepareFolds
      * @param increase_ratio
+     * @return map of each trained solution to its AUC on the held-out fold
      */
-    final TrainingDataset increaseTrueAlert(
-            final List<List<Double>> data,
-            final List<Double> expected,
+    HashMap<AbstractSolution, Double> runKFold(
+            final List<TrainingDataset> prepared_folds,
             final int increase_ratio) {
-        int data_size = expected.size();
-        for (int i = 0; i < data_size; i++) {
-            if (expected.get(i) == 1) {
-                for (int j = 0; j < increase_ratio - 1; j++) {
-                    expected.add(expected.get(i));
-                    data.add(data.get(i));
+        HashMap<AbstractSolution, Double> map = new HashMap<>();
+        for (int i = 0; i < prepared_folds.size(); i++) {
+            TrainingDataset testing = prepared_folds.get(i);
+            TrainingDataset learning = new TrainingDataset();
+            for (int j = 0; j < prepared_folds.size(); j++) {
+                if (j != i) {
+                    learning.addFoldInDataset(prepared_folds, j);
                 }
             }
-        }
-        return new TrainingDataset(data, expected);
-    }
-
-    /**
-     * Method to separate randomly the base dataset in X folds dataset.
-     *
-     * @param dataset
-     * @param fold_number
-     * @return
-     */
-     final List<TrainingDataset> prepareFolds(
-            final TrainingDataset dataset,
-            final int fold_number) {
-        //List<List<Double>> data = dataset.getData();
-        List<Double> expected = dataset.getExpected();
-        List<TrainingDataset> fold_dataset = new ArrayList<>();
-        //Check if it is rounded !!!!
-        int alert_number
-                = (int) Math.floor(Utils.sumListElements(expected)
-                / fold_number);
-        int no_alert_number = (int) (expected.size()
-                - Utils.sumListElements(expected)) / fold_number;
+            TrainingDataset dataset_increased = learning.increaseTrueAlert(
+                    increase_ratio);
+            AbstractSolution sol = run(
+                    dataset_increased.getData(),
+                    dataset_increased.getExpected());
+            Double score = sol.computeAUC(
+                    testing.getData(),
+                    testing.getExpected());
+            System.out.println("Fold number : " + (i + 1) + "AUC : " + score);
 
-        for (int i = 0; i < fold_number; i++) {
-            TrainingDataset tmp = new TrainingDataset();
-            int alert_counter = 0;
-            int no_alert_counter = 0;
-            while (tmp.getLength() < (alert_number + no_alert_number)) {
-                int index = Utils.randomInteger(0, dataset.getLength() - 1);
-                if (dataset.getExpected().get(index) == 1
-                        && alert_counter < alert_number) {
-                    tmp.addElementInDataset(dataset, index);
-                    dataset.removeElementInDataset(index);
-                    alert_counter++;
-                } else if (dataset.getExpected().get(index) == 0
-                        && no_alert_counter < no_alert_number) {
-                    tmp.addElementInDataset(dataset, index);
-                    dataset.removeElementInDataset(index);
-                    no_alert_counter++;
-                }
-            }
-            fold_dataset.add(tmp);
-        }
-        int fold_counter = 0;
-        while (dataset.getLength() > 0) {
-            int i = Utils.randomInteger(0, dataset.getLength() - 1);
-            fold_dataset.get(fold_counter).addElementInDataset(dataset, i);
-            dataset.removeElementInDataset(i);
-            if (fold_counter == fold_dataset.size() - 1) {
-                fold_counter = 0;
-            } else {
-                fold_counter++;
-            }
+            map.put(sol, score);
         }
-        return fold_dataset;
+        return map;
     }
 
+
     /**
      * Find the best element in the population based on its fitness score.
      *
@@ -390,6 +343,7 @@ public class Trainer {
     /**
      * Method used only for tests !!
      * This method generates random number (tos) by using a seed !
+     *
      * @param solutions
      * @param selected_elements
      * @param count
diff --git a/src/main/java/be/cylab/java/wowa/training/TrainingDataset.java b/src/main/java/be/cylab/java/wowa/training/TrainingDataset.java
index 2c75c85e1acbdf0ed1c9d5f888cbe615971e6d46..6d7ff6bb826666d911886309d534f36f32acc4b5 100644
--- a/src/main/java/be/cylab/java/wowa/training/TrainingDataset.java
+++ b/src/main/java/be/cylab/java/wowa/training/TrainingDataset.java
@@ -117,4 +117,76 @@ class TrainingDataset {
         this.length += data_to_add.get(index).getLength();
         return this;
     }
+
+    /**
+     * Method to randomly split the base dataset into fold_number folds.
+     *
+     * @param fold_number number of folds to create
+     * @return the list of folds
+     */
+    final List<TrainingDataset> prepareFolds(
+            final int fold_number) {
+        //List<List<Double>> data = dataset.getData();
+        List<Double> expected = this.getExpected();
+        List<TrainingDataset> fold_dataset = new ArrayList<>();
+        //Check if it is rounded !!!!
+        int alert_number
+                = (int) Math.floor(Utils.sumListElements(expected)
+                / fold_number);
+        int no_alert_number = (int) (expected.size()
+                - Utils.sumListElements(expected)) / fold_number;
+
+        for (int i = 0; i < fold_number; i++) {
+            TrainingDataset tmp = new TrainingDataset();
+            int alert_counter = 0;
+            int no_alert_counter = 0;
+            while (tmp.getLength() < (alert_number + no_alert_number)) {
+                int index = Utils.randomInteger(0, this.getLength() - 1);
+                if (this.getExpected().get(index) == 1
+                        && alert_counter < alert_number) {
+                    tmp.addElementInDataset(this, index);
+                    this.removeElementInDataset(index);
+                    alert_counter++;
+                } else if (this.getExpected().get(index) == 0
+                        && no_alert_counter < no_alert_number) {
+                    tmp.addElementInDataset(this, index);
+                    this.removeElementInDataset(index);
+                    no_alert_counter++;
+                }
+            }
+            fold_dataset.add(tmp);
+        }
+        int fold_counter = 0;
+        while (this.getLength() > 0) {
+            int i = Utils.randomInteger(0, this.getLength() - 1);
+            fold_dataset.get(fold_counter).addElementInDataset(this, i);
+            this.removeElementInDataset(i);
+            if (fold_counter == fold_dataset.size() - 1) {
+                fold_counter = 0;
+            } else {
+                fold_counter++;
+            }
+        }
+        return fold_dataset;
+    }
+
+    /**
+     * Method to oversample the true alerts in the data, increasing
+     * the penalty for failing to detect a true alert.
+     * @param increase_ratio number of copies of each true alert to keep
+     * @return the oversampled dataset
+     */
+    final TrainingDataset increaseTrueAlert(
+            final int increase_ratio) {
+        int data_size = expected.size();
+        for (int i = 0; i < data_size; i++) {
+            if (expected.get(i) == 1) {
+                for (int j = 0; j < increase_ratio - 1; j++) {
+                    expected.add(expected.get(i));
+                    data.add(data.get(i));
+                }
+            }
+        }
+        return new TrainingDataset(data, expected);
+    }
 }
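
The oversampling arithmetic: increaseTrueAlert keeps each true alert increase_ratio times in total (the original plus increase_ratio - 1 copies), so a dataset with t true alerts and n non-alerts grows to t * increase_ratio + n elements, which is what the tests below assert. A tiny sketch, assuming same-package access to the package-private API above (the demo class is hypothetical):

    package be.cylab.java.wowa.training;

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    public class IncreaseTrueAlertDemo {
        public static void main(String[] args) {
            // Toy dataset: 3 elements, one true alert (expected == 1.0).
            List<List<Double>> data = new ArrayList<>();
            List<Double> expected = new ArrayList<>();
            for (double label : new double[] {0.0, 1.0, 0.0}) {
                data.add(Arrays.asList(0.1, 0.2, 0.3, 0.4, 0.5));
                expected.add(label);
            }
            TrainingDataset ds = new TrainingDataset(data, expected);

            // ratio 10: the single alert ends up 10 times in the data,
            // so the length becomes 1 * 10 + 2 = 12.
            System.out.println(ds.increaseTrueAlert(10).getLength());
        }
    }
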
diff --git a/src/main/java/be/cylab/java/wowa/training/Utils.java b/src/main/java/be/cylab/java/wowa/training/Utils.java
index e8174f4be434425d68cfa1251cd77d97b3f8d493..1bcf4a6794e10731ac75884f57cc5773a990470d 100644
--- a/src/main/java/be/cylab/java/wowa/training/Utils.java
+++ b/src/main/java/be/cylab/java/wowa/training/Utils.java
@@ -4,6 +4,9 @@ import com.owlike.genson.GenericType;
 import com.owlike.genson.Genson;
 import info.debatty.java.aggregation.WOWA;
 import org.apache.commons.lang3.ArrayUtils;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.factory.Nd4j;
 
 import java.io.FileOutputStream;
 import java.io.IOException;
@@ -569,4 +572,31 @@ final class Utils {
             e.printStackTrace();
         }
     }
+
+    static DataSet createDataSet(
+            final List<List<Double>> data,
+            final List<Double> expected) {
+        if (data.size() != expected.size()) {
+            throw new IllegalArgumentException(
+                    "Data and Expected must have the same size");
+        }
+        double[][] data_array = new double[data.size()][data.get(0).size()];
+        double[][] expected_array = new double[expected.size()][2];
+        for (int i = 0; i < data.size(); i++) {
+            if (expected.get(i) == 1.0) { // one-hot: 1.0 -> [1, 0], else [0, 1]
+                expected_array[i][0] = 1.0;
+                expected_array[i][1] = 0.0;
+            } else {
+                expected_array[i][0] = 0.0;
+                expected_array[i][1] = 1.0;
+            }
+
+            for (int j = 0; j < data.get(i).size(); j++) {
+                data_array[i][j] = data.get(i).get(j);
+            }
+        }
+        INDArray data_ind = Nd4j.create(data_array);
+        INDArray expected_ind = Nd4j.create(expected_array);
+        return new DataSet(data_ind, expected_ind);
+    }
 }
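
Utils.createDataSet one-hot encodes the expected values into the two-column label matrix required by the two-unit softmax output layer: 1.0 becomes [1, 0] (alert) and anything else becomes [0, 1] (no alert). A minimal sketch, assuming same-package access to the package-private Utils (the demo class is hypothetical):

    package be.cylab.java.wowa.training;

    import org.nd4j.linalg.dataset.DataSet;

    import java.util.Arrays;
    import java.util.List;

    public class CreateDataSetDemo {
        public static void main(String[] args) {
            List<List<Double>> data = Arrays.asList(
                    Arrays.asList(0.9, 0.1),
                    Arrays.asList(0.2, 0.8));
            List<Double> expected = Arrays.asList(1.0, 0.0);

            DataSet ds = Utils.createDataSet(data, expected);
            System.out.println(ds.getLabels()); // [[1, 0], [0, 1]]
        }
    }
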
diff --git a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
index fd5d815f5260415d4c216c6a49134e66f358bf64..41bcd04e790770c42e29660142594b4e4f709fdf 100644
--- a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
+++ b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
@@ -22,66 +22,6 @@ class TrainerTest {
     void testRun() {
     }
 
-    @Test
-    void testIncreaseTrueAlert() {
-        List<List<Double>> data = generateData(100, 5);
-        int original_data_set_size = data.size();
-        List<Double> expected = generateExpectedBinaryClassification(100);
-        int increase_ratio = 10;
-        int number_of_true_alert = (int)(double)Utils.sumListElements(expected);
-        int number_of_no_alert = original_data_set_size - number_of_true_alert;
-        TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio);
-
-        //Check if the length of the dataset us correct
-        //increase_ration * number_of_true_alert + expected.size()
-        assertEquals(number_of_true_alert * increase_ratio + number_of_no_alert, ds.getLength());
-
-        //Check the number of true_alert in the dataset
-        //increase_ration * number_of_true_alert + number_of_true_alert
-        assertEquals(number_of_true_alert * increase_ratio, (double)Utils.sumListElements(ds.getExpected()));
-
-        //Check if each rue alert elements are present (increase_ratio time) in he dataset
-        //Here, we check if each true alert element is present 10 more times than the original one
-        for (int i = 0; i < expected.size(); i++) {
-            if (ds.getExpected().get(i) == 1.0) {
-                int cnt = 0;
-                for (int j = 0; j < ds.getLength(); j++) {
-                    if (ds.getExpected().get(i) == ds.getExpected().get(j)) {
-                        cnt++;
-                    }
-                }
-                assertEquals(increase_ratio, cnt);
-            }
-        }
-
-    }
-
-    @Test
-    void testPrepareFolds() {
-        int number_of_elements = 100;
-        List<List<Double>> data = generateData(number_of_elements,5);
-        List<Double> expected = generateExpectedBinaryClassification(number_of_elements);
-        int increase_ratio = 3;
-        int fold_number = 10;
-        int number_of_alert = (int)(double)Utils.sumListElements(expected);
-        int number_of_no_alert = number_of_elements - number_of_alert;
-        TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio);
-        List<TrainingDataset> folds = trainer.prepareFolds(ds, fold_number);
-        assertEquals(fold_number, folds.size());
-        for (int i = 0; i < folds.size(); i++) {
-            assertTrue(folds.get(i).getLength()
-                    == (number_of_alert * increase_ratio + number_of_no_alert)
-                    / fold_number || folds.get(i).getLength()
-                    == 1 + (number_of_alert * increase_ratio + number_of_no_alert)
-                    / fold_number);
-            assertTrue(Utils.sumListElements(folds.get(i).getExpected())
-                    == ((number_of_alert * increase_ratio) / fold_number) ||
-                    Utils.sumListElements(folds.get(i).getExpected())
-                            == 1 + ((number_of_alert * increase_ratio) / fold_number)
-                    || Utils.sumListElements(folds.get(i).getExpected())
-                    == 2 + ((number_of_alert * increase_ratio) / fold_number));
-        }
-    }
 
     @Test
     void testFindBestSolution() {
diff --git a/src/test/java/be/cylab/java/wowa/training/TrainingDatasetTest.java b/src/test/java/be/cylab/java/wowa/training/TrainingDatasetTest.java
index a2eda8a8608d96c4f82146a3a909da0c961ce66f..94a5e59b57e4fc2a725de5eb412e7b424baedbf3 100644
--- a/src/test/java/be/cylab/java/wowa/training/TrainingDatasetTest.java
+++ b/src/test/java/be/cylab/java/wowa/training/TrainingDatasetTest.java
@@ -6,6 +6,7 @@ import org.junit.jupiter.api.Test;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Random;
 
 import static org.junit.jupiter.api.Assertions.*;
 
@@ -108,4 +109,100 @@ class TrainingDatasetTest {
             }
         }
     }
+
+    @Test
+    void testIncreaseTrueAlert() {
+        List<List<Double>> data = generateData(100, 5);
+        int original_data_set_size = data.size();
+        List<Double> expected = generateExpectedBinaryClassification(100);
+        int increase_ratio = 10;
+        int number_of_true_alert = (int)(double)Utils.sumListElements(expected);
+        int number_of_no_alert = original_data_set_size - number_of_true_alert;
+        TrainingDataset ds = new TrainingDataset(data, expected).increaseTrueAlert(increase_ratio);
+
+        //Check if the length of the dataset is correct:
+        //increase_ratio * number_of_true_alert + number_of_no_alert
+        assertEquals(number_of_true_alert * increase_ratio + number_of_no_alert, ds.getLength());
+
+        //Check the number of true alerts in the dataset:
+        //increase_ratio * number_of_true_alert
+        assertEquals(number_of_true_alert * increase_ratio, (double)Utils.sumListElements(ds.getExpected()));
+
+        //Check that each true alert element is present increase_ratio times
+        //in the dataset (increase_ratio copies of the original element)
+        for (int i = 0; i < expected.size(); i++) {
+            if (ds.getExpected().get(i) == 1.0) {
+                int cnt = 0;
+                for (int j = 0; j < ds.getLength(); j++) {
+                    if (ds.getExpected().get(i) == ds.getExpected().get(j)) {
+                        cnt++;
+                    }
+                }
+                assertEquals(increase_ratio, cnt);
+            }
+        }
+
+    }
+
+    @Test
+    void testPrepareFolds() {
+        int number_of_elements = 100;
+        List<List<Double>> data = generateData(number_of_elements,5);
+        List<Double> expected = generateExpectedBinaryClassification(number_of_elements);
+        int increase_ratio = 3;
+        int fold_number = 10;
+        int number_of_alert = (int)(double)Utils.sumListElements(expected);
+        int number_of_no_alert = number_of_elements - number_of_alert;
+        TrainingDataset ds = new TrainingDataset(data, expected).increaseTrueAlert(increase_ratio);
+        List<TrainingDataset> folds = ds.prepareFolds(fold_number);
+        assertEquals(fold_number, folds.size());
+        for (int i = 0; i < folds.size(); i++) {
+            assertTrue(folds.get(i).getLength()
+                    == (number_of_alert * increase_ratio + number_of_no_alert)
+                    / fold_number || folds.get(i).getLength()
+                    == 1 + (number_of_alert * increase_ratio + number_of_no_alert)
+                    / fold_number);
+            assertTrue(Utils.sumListElements(folds.get(i).getExpected())
+                    == ((number_of_alert * increase_ratio) / fold_number) ||
+                    Utils.sumListElements(folds.get(i).getExpected())
+                            == 1 + ((number_of_alert * increase_ratio) / fold_number)
+                    || Utils.sumListElements(folds.get(i).getExpected())
+                    == 2 + ((number_of_alert * increase_ratio) / fold_number));
+        }
+    }
+
+    static List<List<Double>> generateData(final int size, final int weight_number) {
+        Random rnd = new Random(5489);
+        List<List<Double>> data = new ArrayList<>();
+        for (int i = 0; i < size; i++) {
+            List<Double> vector = new ArrayList<>();
+            for (int j = 0; j < weight_number; j++) {
+                vector.add(rnd.nextDouble());
+            }
+            data.add(vector);
+        }
+        return data;
+    }
+
+    static List<Double> generateExpected(final int size) {
+        Random rnd = new Random(5768);
+        List<Double> expected = new ArrayList<>();
+        for (int i = 0; i < size; i++) {
+            expected.add(rnd.nextDouble());
+        }
+        return expected;
+    }
+
+    static List<Double> generateExpectedBinaryClassification(final int size) {
+        Random rnd = new Random(5768);
+        List<Double> expected = new ArrayList<>();
+        for (int i = 0; i < size; i++) {
+            if (rnd.nextDouble() <= 0.5) {
+                expected.add(0.0);
+            } else {
+                expected.add(1.0);
+            }
+        }
+        return expected;
+    }
 }
\ No newline at end of file