diff --git a/.gitignore b/.gitignore
index 0b1592a6d8ec776d172a498cb860944b83173f10..9ba3e7030b1ec399f6cbc25cd0ab14f60011ac36 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,4 @@
 /nbproject/
 /target/
 /.idea/
-java-wowa-training.iml
+/java-wowa-training.iml
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 542a1bb1b357850381cdc258bbe5924933fce840..fba36d30c252c00a51ed3078a8c2b234aaa3392f 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -22,20 +22,6 @@ test:mvn:jdk8:
     - mvn clean verify -Dgpg.skip -Dmaven.repo.local=.m2
 
 
-
-test:mvn:jdk9:
-  stage: test
-  image: maven:3.5.3-jdk-9
-  script:
-    - mvn clean verify -Dgpg.skip -Dmaven.repo.local=.m2
-
-
-test:mvn:jdk10:
-  stage: test
-  image: maven:3.5.3-jdk-10
-  script:
-    - mvn clean verify -Dgpg.skip -Dmaven.repo.local=.m2
-
 ##mvn:jdk11:
 ##  image: maven:3.5.3-jdk-11
 ##  script:
diff --git a/java-wowa-training.iml b/java-wowa-training.iml
index 8e59e92c2185fc6f66fe8ad20f93896426f83069..e84bf0adeb5339a52ed7612df451c6b6208ea764 100644
--- a/java-wowa-training.iml
+++ b/java-wowa-training.iml
@@ -32,7 +32,97 @@
         <SOURCES />
       </library>
     </orderEntry>
-    <orderEntry type="library" name="Maven: be.cylab:java-roc:0.0.3" level="project" />
+    <orderEntry type="library" name="Maven: org.datavec:datavec-api:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.apache.commons:commons-compress:1.8.1" level="project" />
+    <orderEntry type="library" name="Maven: org.apache.commons:commons-math3:3.3" level="project" />
+    <orderEntry type="library" name="Maven: commons-io:commons-io:2.4" level="project" />
+    <orderEntry type="library" name="Maven: org.slf4j:slf4j-api:1.7.7" level="project" />
+    <orderEntry type="library" name="Maven: joda-time:joda-time:2.9.2" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:jackson:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.yaml:snakeyaml:1.12" level="project" />
+    <orderEntry type="library" name="Maven: org.codehaus.woodstox:stax2-api:3.1.4" level="project" />
+    <orderEntry type="library" name="Maven: org.projectlombok:lombok:1.16.10" level="project" />
+    <orderEntry type="library" name="Maven: org.freemarker:freemarker:2.3.23" level="project" />
+    <orderEntry type="library" name="Maven: org.reflections:reflections:0.9.10" level="project" />
+    <orderEntry type="library" name="Maven: org.javassist:javassist:3.19.0-GA" level="project" />
+    <orderEntry type="library" name="Maven: com.google.code.findbugs:annotations:2.0.1" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:nd4j-common:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: com.github.stephenc.findbugs:findbugs-annotations:1.3.9-1" level="project" />
+    <orderEntry type="library" name="Maven: com.clearspring.analytics:stream:2.7.0" level="project" />
+    <orderEntry type="library" name="Maven: it.unimi.dsi:fastutil:6.5.7" level="project" />
+    <orderEntry type="library" name="Maven: net.sf.opencsv:opencsv:2.3" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:nd4j-api:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:nd4j-buffer:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:nd4j-context:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: net.ericaro:neoitertools:1.0.0" level="project" />
+    <orderEntry type="library" name="Maven: junit:junit:4.8.2" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:nd4j-native:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:nd4j-native:linux-x86_64:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco:javacpp:1.3.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:openblas:0.2.19-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:openblas:linux-x86_64:0.2.19-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:nd4j-native-api:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.deeplearning4j:deeplearning4j-nn:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:nd4j-jackson:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:nd4j-base64:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: commons-net:commons-net:3.1" level="project" />
+    <orderEntry type="library" name="Maven: org.deeplearning4j:deeplearning4j-core:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.deeplearning4j:nearestneighbor-core:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.deeplearning4j:deeplearning4j-modelimport:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5-platform:1.10.0-patch1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:1.10.0-patch1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:linux-x86:1.10.0-patch1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:linux-x86_64:1.10.0-patch1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:linux-ppc64le:1.10.0-patch1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:macosx-x86_64:1.10.0-patch1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:windows-x86:1.10.0-patch1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:windows-x86_64:1.10.0-patch1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: com.google.guava:guava:20.0" level="project" />
+    <orderEntry type="library" name="Maven: org.datavec:datavec-data-image:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: com.github.jai-imageio:jai-imageio-core:1.3.0" level="project" />
+    <orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-jpeg:3.1.1" level="project" />
+    <orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-core:3.1.1" level="project" />
+    <orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-metadata:3.1.1" level="project" />
+    <orderEntry type="library" name="Maven: com.twelvemonkeys.common:common-lang:3.1.1" level="project" />
+    <orderEntry type="library" name="Maven: com.twelvemonkeys.common:common-io:3.1.1" level="project" />
+    <orderEntry type="library" name="Maven: com.twelvemonkeys.common:common-image:3.1.1" level="project" />
+    <orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-tiff:3.1.1" level="project" />
+    <orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-psd:3.1.1" level="project" />
+    <orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-bmp:3.1.1" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco:javacv:1.3.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:ffmpeg:3.2.1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:flycapture:2.9.3.43-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:libdc1394:2.2.4-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:libfreenect:0.5.3-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:libfreenect2:0.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:librealsense:1.9.6-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:videoinput:0.200-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:artoolkitplus:2.3.1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:flandmark:1.07-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv-platform:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:android-arm:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:android-x86:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:linux-x86:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:linux-x86_64:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:linux-armhf:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:linux-ppc64le:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:macosx-x86_64:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:windows-x86:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:windows-x86_64:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica-platform:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:android-arm:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:android-x86:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:linux-x86:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:linux-x86_64:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:linux-armhf:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:linux-ppc64le:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:macosx-x86_64:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:windows-x86:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:windows-x86_64:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.deeplearning4j:deeplearning4j-ui-components:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: commons-codec:commons-codec:1.10" level="project" />
     <orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-api:5.3.1" level="project" />
     <orderEntry type="library" scope="TEST" name="Maven: org.apiguardian:apiguardian-api:1.0.0" level="project" />
     <orderEntry type="library" scope="TEST" name="Maven: org.opentest4j:opentest4j:1.1.1" level="project" />
@@ -52,5 +142,96 @@
     <orderEntry type="library" name="Maven: org.apache.commons:commons-collections4:4.2" level="project" />
     <orderEntry type="library" name="Maven: org.knowm.xchart:xchart:3.5.4" level="project" />
     <orderEntry type="library" name="Maven: de.erichseifert.vectorgraphics2d:VectorGraphics2D:0.13" level="project" />
+    <orderEntry type="library" name="Maven: org.datavec:datavec-api:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.apache.commons:commons-compress:1.8.1" level="project" />
+    <orderEntry type="library" name="Maven: org.apache.commons:commons-math3:3.3" level="project" />
+    <orderEntry type="library" name="Maven: commons-io:commons-io:2.4" level="project" />
+    <orderEntry type="library" name="Maven: org.slf4j:slf4j-api:1.7.7" level="project" />
+    <orderEntry type="library" name="Maven: joda-time:joda-time:2.9.2" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:jackson:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.yaml:snakeyaml:1.12" level="project" />
+    <orderEntry type="library" name="Maven: org.codehaus.woodstox:stax2-api:3.1.4" level="project" />
+    <orderEntry type="library" name="Maven: org.projectlombok:lombok:1.16.10" level="project" />
+    <orderEntry type="library" name="Maven: org.freemarker:freemarker:2.3.23" level="project" />
+    <orderEntry type="library" name="Maven: org.reflections:reflections:0.9.10" level="project" />
+    <orderEntry type="library" name="Maven: org.javassist:javassist:3.19.0-GA" level="project" />
+    <orderEntry type="library" name="Maven: com.google.code.findbugs:annotations:2.0.1" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:nd4j-common:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: com.github.stephenc.findbugs:findbugs-annotations:1.3.9-1" level="project" />
+    <orderEntry type="library" name="Maven: com.clearspring.analytics:stream:2.7.0" level="project" />
+    <orderEntry type="library" name="Maven: it.unimi.dsi:fastutil:6.5.7" level="project" />
+    <orderEntry type="library" name="Maven: net.sf.opencsv:opencsv:2.3" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:nd4j-api:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:nd4j-buffer:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:nd4j-context:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: net.ericaro:neoitertools:1.0.0" level="project" />
+    <orderEntry type="library" name="Maven: junit:junit:4.8.2" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:nd4j-native:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:nd4j-native:linux-x86_64:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco:javacpp:1.3.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:openblas:0.2.19-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:openblas:linux-x86_64:0.2.19-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:nd4j-native-api:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.deeplearning4j:deeplearning4j-nn:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:nd4j-jackson:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.nd4j:nd4j-base64:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: commons-net:commons-net:3.1" level="project" />
+    <orderEntry type="library" name="Maven: org.deeplearning4j:deeplearning4j-core:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.deeplearning4j:nearestneighbor-core:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.deeplearning4j:deeplearning4j-modelimport:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5-platform:1.10.0-patch1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:1.10.0-patch1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:linux-x86:1.10.0-patch1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:linux-x86_64:1.10.0-patch1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:linux-ppc64le:1.10.0-patch1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:macosx-x86_64:1.10.0-patch1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:windows-x86:1.10.0-patch1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:windows-x86_64:1.10.0-patch1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: com.google.guava:guava:20.0" level="project" />
+    <orderEntry type="library" name="Maven: org.datavec:datavec-data-image:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: com.github.jai-imageio:jai-imageio-core:1.3.0" level="project" />
+    <orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-jpeg:3.1.1" level="project" />
+    <orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-core:3.1.1" level="project" />
+    <orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-metadata:3.1.1" level="project" />
+    <orderEntry type="library" name="Maven: com.twelvemonkeys.common:common-lang:3.1.1" level="project" />
+    <orderEntry type="library" name="Maven: com.twelvemonkeys.common:common-io:3.1.1" level="project" />
+    <orderEntry type="library" name="Maven: com.twelvemonkeys.common:common-image:3.1.1" level="project" />
+    <orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-tiff:3.1.1" level="project" />
+    <orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-psd:3.1.1" level="project" />
+    <orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-bmp:3.1.1" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco:javacv:1.3.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:ffmpeg:3.2.1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:flycapture:2.9.3.43-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:libdc1394:2.2.4-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:libfreenect:0.5.3-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:libfreenect2:0.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:librealsense:1.9.6-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:videoinput:0.200-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:artoolkitplus:2.3.1-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:flandmark:1.07-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv-platform:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:android-arm:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:android-x86:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:linux-x86:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:linux-x86_64:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:linux-armhf:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:linux-ppc64le:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:macosx-x86_64:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:windows-x86:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:windows-x86_64:3.2.0-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica-platform:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:android-arm:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:android-x86:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:linux-x86:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:linux-x86_64:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:linux-armhf:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:linux-ppc64le:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:macosx-x86_64:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:windows-x86:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:windows-x86_64:1.73-1.3" level="project" />
+    <orderEntry type="library" name="Maven: org.deeplearning4j:deeplearning4j-ui-components:0.9.1" level="project" />
+    <orderEntry type="library" name="Maven: commons-codec:commons-codec:1.10" level="project" />
   </component>
 </module>
\ No newline at end of file
diff --git a/pom.xml b/pom.xml
index fc98116fb91ba5d89717aa502956bc26e1bcf104..3db99b868ce937228db506f39f03f93b4ca07baa 100644
--- a/pom.xml
+++ b/pom.xml
@@ -98,6 +98,43 @@
             <version>0.0.4</version>
         </dependency>
 
+        <dependency>
+            <groupId>org.datavec</groupId>
+            <artifactId>datavec-api</artifactId>
+            <version>0.9.1</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.nd4j</groupId>
+            <artifactId>nd4j-api</artifactId>
+            <version>0.9.1</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.nd4j</groupId>
+            <artifactId>nd4j-native</artifactId>
+            <version>0.9.1</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.deeplearning4j</groupId>
+            <artifactId>deeplearning4j-nn</artifactId>
+            <version>0.9.1</version>
+        </dependency>
+
+
+        <dependency>
+            <groupId>org.deeplearning4j</groupId>
+            <artifactId>deeplearning4j-core</artifactId>
+            <version>0.9.1</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.nd4j</groupId>
+            <artifactId>nd4j-common</artifactId>
+            <version>0.9.1</version>
+        </dependency>
+
 
     </dependencies>
 
@@ -196,6 +233,7 @@
                             <ignoredDependencies>
                                 <ignoreDependency>org.junit.jupiter:junit-jupiter-engine:*</ignoreDependency>
                                 <ignoreDependency>org.junit.jupiter:junit-jupiter-params:*</ignoreDependency>
+                                <ignoredDependency>org.nd4j:nd4j-native</ignoredDependency>
                             </ignoredDependencies>
                         </configuration>
                         <goals>
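Note: the pom now pulls in the DL4J 0.9.1 stack (datavec-api, nd4j-api, nd4j-native, deeplearning4j-nn, deeplearning4j-core, nd4j-common). A minimal smoke test to confirm the native backend resolves on the local platform (class name hypothetical, not part of this change):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public final class Nd4jSmokeTest {
        public static void main(final String[] args) {
            // Builds a 2x2 matrix on the nd4j-native backend and multiplies it.
            INDArray m = Nd4j.create(new double[][]{{1, 2}, {3, 4}});
            System.out.println(m.mmul(m));
        }
    }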
diff --git a/src/main/java/be/cylab/java/wowa/training/Example.java b/src/main/java/be/cylab/java/wowa/training/Example.java
index d1055c0ba01da1379640af927d2298bbb5edc1e3..9b013ce2d2dfd8b8740046ea1cb6f166f4b67fcd 100644
--- a/src/main/java/be/cylab/java/wowa/training/Example.java
+++ b/src/main/java/be/cylab/java/wowa/training/Example.java
@@ -48,8 +48,9 @@ public final class Example {
                     "Initialization must be RANDOM or QUASI_RANDOM");
         }
 
-        String data_file = "./ressources/webshell_data.json";
-        String expected_file = "./ressources/webshell_expected.json";
+        String data_file = "./ressources/webshell_data_new_version.json";
+        String expected_file
+                = "./ressources/webshell_expected_new_version.json";
         Logger logger = Logger.getLogger(Trainer.class.getName());
         logger.setLevel(Level.INFO);
         StreamHandler handler = new StreamHandler(System.out,
@@ -66,8 +67,8 @@ public final class Example {
 
         long start_time = System.currentTimeMillis();
         AbstractSolution solution = trainer.run(
-                "./ressources/webshell_data.json",
-                "./ressources/webshell_expected.json");
+                data_file,
+                expected_file);
         System.out.println(solution);
         long end_time = System.currentTimeMillis();
         logger.log(Level.INFO, "Execution time : "
@@ -78,11 +79,14 @@ public final class Example {
                 data_file,
                 expected_file,
                 10,
-                10);
+                2);
 
+        double average_auc = 0;
         for (Map.Entry val : solutions.entrySet()) {
             System.out.println(val);
+            average_auc = average_auc + (double) val.getValue();
         }
+        logger.log(Level.INFO,
+                "Average AUC : " + average_auc / solutions.size());
 
     }
 
diff --git a/src/main/java/be/cylab/java/wowa/training/HyperParameters.java b/src/main/java/be/cylab/java/wowa/training/HyperParameters.java
new file mode 100644
index 0000000000000000000000000000000000000000..9915e6e3f926f7bb17b4cfdb4290aefbf5c0be8a
--- /dev/null
+++ b/src/main/java/be/cylab/java/wowa/training/HyperParameters.java
@@ -0,0 +1,140 @@
+package be.cylab.java.wowa.training;
+
+import org.deeplearning4j.nn.api.OptimizationAlgorithm;
+import org.nd4j.linalg.activations.Activation;
+
+/**
+ * Parameters of the NeuralNetwork class.
+ */
+public class HyperParameters {
+    private int neurons_number;
+    private double learning_rate;
+    private OptimizationAlgorithm algorithm;
+    private Activation activation_function;
+    private double percent_test_train;
+
+    /**
+     * Full constructor.
+     * @param neurons_number
+     * @param learning_rate
+     * @param algorithm
+     * @param activation_function
+     * @param percent_test_train
+     */
+    public HyperParameters(
+            final int neurons_number,
+            final double learning_rate,
+            final OptimizationAlgorithm algorithm,
+            final Activation activation_function,
+            final double percent_test_train) {
+        setNeuronsNumber(neurons_number);
+        setLearningRate(learning_rate);
+        setAlgorithm(algorithm);
+        setActivationFunction(activation_function);
+        setPercentTestTrain(percent_test_train);
+    }
+
+    /**
+     * Constructor without percent_test_train value (defaults to 99).
+     * @param neurons_number
+     * @param learning_rate
+     * @param algorithm
+     * @param activation_function
+     */
+    public HyperParameters(
+            final int neurons_number,
+            final double learning_rate,
+            final OptimizationAlgorithm algorithm,
+            final Activation activation_function) {
+        this(neurons_number, learning_rate, algorithm,
+                activation_function, 99);
+    }
+
+    /**
+     * Getter for neurons_number.
+     * @return
+     */
+    public final int getNeuronsNumber() {
+        return neurons_number;
+    }
+
+    /**
+     * Getter for learning rate.
+     * @return
+     */
+    public final double getLearningRate() {
+        return learning_rate;
+    }
+
+    /**
+     * Getter for the optimization algorithm.
+     * @return
+     */
+    public final OptimizationAlgorithm getAlgorithm() {
+        return algorithm;
+    }
+
+    /**
+     * Getter for activation function.
+     * @return
+     */
+    public final Activation getActivationFunction() {
+        return activation_function;
+    }
+
+    /**
+     * Getter for percent_test_train.
+     * @return
+     */
+    public final double getPercentTestTrain() {
+        return percent_test_train;
+    }
+
+    /**
+     * @param neurons_number
+     */
+    public final void setNeuronsNumber(final int neurons_number) {
+        if (neurons_number < 5 || neurons_number > 80) {
+            throw new IllegalArgumentException(
+                    "Neuron number must be between 5 and 80");
+        }
+        this.neurons_number = neurons_number;
+    }
+
+    /**
+     * @param learning_rate
+     */
+    public final void setLearningRate(final double learning_rate) {
+        if (learning_rate <= 0.0 || learning_rate >= 1.0) {
+            throw new IllegalArgumentException(
+                    "Learning rate must be between 0 and 1");
+        }
+        this.learning_rate = learning_rate;
+    }
+
+    /**
+     * @param algorithm
+     */
+    public final void setAlgorithm(final OptimizationAlgorithm algorithm) {
+        this.algorithm = algorithm;
+    }
+
+    /**
+     * @param activation_function
+     */
+    public final void setActivationFunction(
+            final Activation activation_function) {
+        this.activation_function = activation_function;
+    }
+
+    /**
+     * @param percent_test_train
+     */
+    public final void setPercentTestTrain(final double percent_test_train) {
+        if (percent_test_train <= 0 || percent_test_train >= 100) {
+            throw new IllegalArgumentException(
+                    "Percentage of train must be between 0 and 100");
+        }
+        this.percent_test_train = percent_test_train;
+    }
+}
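Usage sketch for the new class; the values are illustrative. The 4-argument constructor delegates to the full one with percent_test_train = 99:

    HyperParameters params = new HyperParameters(
            20,     // neurons_number, validated to [5, 80]
            0.01,   // learning_rate, validated to (0, 1)
            OptimizationAlgorithm.CONJUGATE_GRADIENT,
            Activation.TANH);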
diff --git a/src/main/java/be/cylab/java/wowa/training/MainDL4J.java b/src/main/java/be/cylab/java/wowa/training/MainDL4J.java
new file mode 100644
index 0000000000000000000000000000000000000000..b5be71399cba6f65f3343a53a08b8ae7edc0ccb9
--- /dev/null
+++ b/src/main/java/be/cylab/java/wowa/training/MainDL4J.java
@@ -0,0 +1,99 @@
+package be.cylab.java.wowa.training;
+
+import org.deeplearning4j.nn.api.OptimizationAlgorithm;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.nd4j.linalg.activations.Activation;
+
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+
+/**
+ * Class for learning with a neural network.
+ */
+public final class MainDL4J {
+    /**
+     * Default constructor.
+     */
+    private MainDL4J() {
+
+    }
+
+    /**
+     * Main class for neural-network learning.
+     *
+     * @param args
+     */
+    public static void main(final String[] args) {
+        int begin_neuron_number = Integer.parseInt(args[0]);
+        int neuron_number_end = Integer.parseInt(args[1]);
+        String data_file = args[5];
+        String expected_file = args[6];
+        double learning_rate = Double.parseDouble(args[2]);
+        OptimizationAlgorithm optimization_algorithm;
+        Activation activation_function;
+        int fold_number = 10;
+        int increase_ratio = 10;
+
+        if (args[3].matches("CONJUGATE_GRADIENT")) {
+            optimization_algorithm = OptimizationAlgorithm.CONJUGATE_GRADIENT;
+        } else if (args[3].matches("LBFGS")) {
+            optimization_algorithm = OptimizationAlgorithm.LBFGS;
+        } else if (args[3].matches("LINE_GRADIENT_DESCENT")) {
+            optimization_algorithm
+                    = OptimizationAlgorithm.LINE_GRADIENT_DESCENT;
+        } else if (args[3].matches("STOCHASTIC_GRADIENT_DESCENT")) {
+            optimization_algorithm
+                    = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
+        } else {
+            throw new IllegalArgumentException(
+                    "Not correct optimization algorithm");
+        }
+
+        if (args[4].matches("RELU")) {
+            activation_function = Activation.RELU;
+        } else if (args[4].matches("SIGMOID")) {
+            activation_function = Activation.SIGMOID;
+        } else if (args[4].matches("TANH")) {
+            activation_function = Activation.TANH;
+        } else {
+            throw new IllegalArgumentException(
+                    "Not correct activation function");
+        }
+
+        for (int neuron_number = begin_neuron_number;
+             neuron_number <= neuron_number_end; neuron_number++) {
+            double average_auc = 0;
+            System.out.println("Neuron number : " + neuron_number);
+            HyperParameters parameters
+                    = new HyperParameters(neuron_number, learning_rate,
+                    optimization_algorithm, activation_function);
+            NeuralNetwork nn = new NeuralNetwork(parameters);
+
+            HashMap<MultiLayerNetwork, Double> map = nn.runKFold(
+                    data_file,
+                    expected_file,
+                    fold_number,
+                    increase_ratio);
+            for (Double d : map.values()) {
+                average_auc = average_auc + d;
+            }
+            try (OutputStreamWriter writer = new OutputStreamWriter(
+                    new FileOutputStream("Synthesis_average_AUC.txt", true),
+                    StandardCharsets.UTF_8)) {
+                writer.write("Neuron number : " + neuron_number
+                        + " Learning rate : " + learning_rate
+                        + " Algorithm : " + args[3]
+                        + " Activation function : " + args[4]
+                        + " Average AUC = "
+                        + (average_auc / fold_number) + "\n");
+            } catch (IOException e) {
+                e.printStackTrace();
+            }
+        }
+
+
+    }
+}
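MainDL4J sweeps the neuron count from args[0] to args[1], runs 10-fold cross-validation for each configuration, and appends the average AUC to Synthesis_average_AUC.txt. A sketch of the expected positional arguments (paths illustrative):

    String[] args = {
            "5",                    // args[0]: first neuron count to try
            "20",                   // args[1]: last neuron count to try
            "0.01",                 // args[2]: learning rate
            "CONJUGATE_GRADIENT",   // args[3]: optimization algorithm
            "TANH",                 // args[4]: activation function
            "./ressources/webshell_data_new_version.json",      // args[5]
            "./ressources/webshell_expected_new_version.json"   // args[6]
    };
    MainDL4J.main(args);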
diff --git a/src/main/java/be/cylab/java/wowa/training/MainTest.java b/src/main/java/be/cylab/java/wowa/training/MainTest.java
new file mode 100644
index 0000000000000000000000000000000000000000..4b940c4c17f78be9c0211c2b2e65652d938989e5
--- /dev/null
+++ b/src/main/java/be/cylab/java/wowa/training/MainTest.java
@@ -0,0 +1,93 @@
+package be.cylab.java.wowa.training;
+
+import org.deeplearning4j.nn.api.OptimizationAlgorithm;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.nd4j.linalg.activations.Activation;
+
+import java.util.HashMap;
+import java.util.List;
+
+/**
+ * Main class to test and compare the efficiency of the neural network
+ * (nn) and the WOWA trainer (wt).
+ */
+public final class MainTest {
+
+    private MainTest() {
+
+    }
+
+    /**
+     * @param args
+     */
+    public static void main(final String[] args) {
+        int population_size = 100;
+        int crossover_rate = 60;
+        int mutation_rate = 15;
+        int max_generation_number = 120;
+        int selection_method
+                = TrainerParameters.SELECTION_METHOD_RWS;
+        int generation_population_method
+                = TrainerParameters.POPULATION_INITIALIZATION_RANDOM;
+        String data_file = "./ressources/webshell_data_new_version.json";
+        String expected_file = "ressources/webshell_expected_new_version.json";
+        List<List<Double>> data
+                = Utils.convertJsonToDataForTrainer(data_file);
+        List<Double> expected
+                = Utils.convertJsonToExpectedForTrainer(expected_file);
+
+        TrainerParameters wt_parameters = new TrainerParameters(
+                null,
+                population_size,
+                crossover_rate,
+                mutation_rate,
+                max_generation_number,
+                selection_method,
+                generation_population_method);
+        Trainer trainer = new Trainer(wt_parameters,
+                new SolutionDistance(5));
+
+        int neurons_number = 20;
+        double learning_rate = 0.01;
+        OptimizationAlgorithm algo = OptimizationAlgorithm.CONJUGATE_GRADIENT;
+        Activation activation_function = Activation.TANH;
+        HyperParameters nn_parameters = new HyperParameters(
+                neurons_number,
+                learning_rate,
+                algo,
+                activation_function
+        );
+        NeuralNetwork nn = new NeuralNetwork(nn_parameters);
+
+        TrainingDataset dataset = new TrainingDataset(data, expected);
+        List<TrainingDataset> folds = dataset.prepareFolds(10);
+
+        long start_time = System.currentTimeMillis();
+        System.out.println("Wowa training");
+        HashMap<AbstractSolution, Double> map_wt = trainer.runKFold(folds, 2);
+        long end_time = System.currentTimeMillis();
+        System.out.println("Execution time : " + (end_time - start_time) / 1000
+                + " seconds");
+
+        start_time = System.currentTimeMillis();
+        System.out.println("Neural Network learning");
+        HashMap<MultiLayerNetwork, Double> map_nn = nn.runKFold(folds, 10);
+        end_time = System.currentTimeMillis();
+        System.out.println("Execution time : " + (end_time - start_time) / 1000
+                + " seconds");
+
+        double nn_score = 0.0;
+        double wt_score = 0.0;
+        for (Double d : map_nn.values()) {
+            nn_score = nn_score + d;
+        }
+
+        System.out.println("Average AUC for Neural Network learning : "
+                + nn_score / 10);
+
+        for (Double d : map_wt.values()) {
+            wt_score = wt_score + d;
+        }
+
+        System.out.println("Average AUC for WOWA learning : " + wt_score / 10);
+    }
+}
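Design note: MainTest builds the 10 folds once with dataset.prepareFolds(10) and hands the same folds to both trainer.runKFold(folds, 2) and nn.runKFold(folds, 10), so the two average AUCs are computed on identical train/test splits; only the true-alert increase ratio differs (2 for the WOWA trainer, 10 for the network).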
diff --git a/src/main/java/be/cylab/java/wowa/training/NeuralNetwork.java b/src/main/java/be/cylab/java/wowa/training/NeuralNetwork.java
new file mode 100644
index 0000000000000000000000000000000000000000..35d6a1068444aa9f0321fb31029531690b6daeb5
--- /dev/null
+++ b/src/main/java/be/cylab/java/wowa/training/NeuralNetwork.java
@@ -0,0 +1,278 @@
+package be.cylab.java.wowa.training;
+
+import org.datavec.api.records.reader.RecordReader;
+import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
+import org.datavec.api.split.FileSplit;
+import org.datavec.api.util.ClassPathResource;
+import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
+import org.deeplearning4j.eval.Evaluation;
+import org.deeplearning4j.eval.ROC;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.dataset.SplitTestAndTrain;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
+import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+
+/**
+ * Class to test neural network learning.
+ */
+
+public final class NeuralNetwork {
+
+    /**
+     * Class count.
+     */
+    public static final int CLASSES_COUNT = 2;
+    /**
+     * Features count.
+     */
+    public static final int FEATURES_COUNT = 5;
+
+
+    private HyperParameters parameters;
+
+
+    /**
+     * Constructor.
+     *
+     * @param parameters
+     */
+    public NeuralNetwork(final HyperParameters parameters) {
+        this.parameters = parameters;
+    }
+
+    /**
+     * @param learning
+     * @return
+     */
+    MultiLayerNetwork learning(
+            final DataSet learning) {
+
+        int neurons_number = parameters.getNeuronsNumber();
+        MultiLayerConfiguration configuration
+                = new NeuralNetConfiguration.Builder()
+                .iterations(1000)
+                .activation(parameters.getActivationFunction())
+                .optimizationAlgo(
+                        parameters.getAlgorithm())
+                .weightInit(WeightInit.XAVIER)
+                .learningRate(parameters.getLearningRate())
+                .regularization(true).l2(0.0001)
+                .list()
+                .layer(0, new DenseLayer.Builder()
+                        .nIn(learning.getFeatures().columns())
+                        .nOut(neurons_number).build())
+                .layer(1, new DenseLayer.Builder().nIn(neurons_number)
+                        .nOut(neurons_number).build())
+                .layer(2, new OutputLayer.Builder(LossFunctions
+                        .LossFunction.NEGATIVELOGLIKELIHOOD)
+                        .activation(Activation.SOFTMAX)
+                        .nIn(neurons_number)
+                        .nOut(learning.getLabels().columns()).build())
+                .backprop(true).pretrain(false)
+                .build();
+
+        MultiLayerNetwork model = new MultiLayerNetwork(configuration);
+        model.init();
+        model.fit(learning);
+
+        return model;
+    }
+
+    /**
+     * @param data
+     * @param expected
+     * @return
+     */
+    public MultiLayerNetwork run(
+            final List<List<Double>> data,
+            final List<Double> expected) {
+
+        DataSet all_data = Utils.createDataSet(data, expected);
+        //DataSet training_data
+        //        = prepareDataSetForTrainingAndTesting(all_data).getTrain();
+        return this.learning(all_data);
+
+    }
+
+    /**
+     * @param data_filename
+     * @param expected_filename
+     * @return
+     */
+    public MultiLayerNetwork run(
+            final String data_filename,
+            final String expected_filename) {
+        List<List<Double>> data
+                = Utils.convertJsonToDataForTrainer(data_filename);
+        List<Double> expected
+                = Utils.convertJsonToExpectedForTrainer(expected_filename);
+        return this.run(data, expected);
+    }
+
+    /**
+     * @param filename
+     * @return
+     */
+    public MultiLayerNetwork runCSV(final String filename) {
+        try (RecordReader record_reader = new CSVRecordReader(0, ',')) {
+            record_reader.initialize(new FileSplit(
+                    new ClassPathResource(filename).getFile()
+            ));
+            DataSetIterator iterator = new RecordReaderDataSetIterator(
+                    record_reader, 12468, FEATURES_COUNT, CLASSES_COUNT);
+            DataSet all_data = iterator.next();
+            //DataSet training_data
+            //       = prepareDataSetForTrainingAndTesting(all_data).getTrain();
+            return learning(all_data);
+
+        } catch (IOException | InterruptedException e) {
+            e.printStackTrace();
+        }
+        return null;
+    }
+
+    /**
+     * Method to evaluate the performance of the model.
+     *
+     * @param testing
+     * @param network
+     */
+    public Double modelEvaluation(
+            final DataSet testing,
+            final MultiLayerNetwork network) {
+        INDArray output
+                = network.output(testing.getFeatureMatrix());
+
+        Evaluation eval = new Evaluation(testing.getLabels().columns());
+        eval.eval(testing.getLabels(), output);
+        System.out.println(eval.stats());
+        ROC roc = new ROC(testing.getLabels().columns());
+        roc.eval(testing.getLabels(), output);
+        System.out.println(roc.stats());
+        return roc.calculateAUC();
+    }
+
+    /**
+     * @param testing
+     * @param network
+     * @return
+     */
+    public Double modelEvaluation(
+            final TrainingDataset testing,
+            final MultiLayerNetwork network) {
+        DataSet test
+                = Utils.createDataSet(testing.getData(), testing.getExpected());
+        return modelEvaluation(test, network);
+    }
+
+    /**
+     * @param data
+     * @return
+     */
+    SplitTestAndTrain prepareDataSetForTrainingAndTesting(final DataSet data) {
+        data.shuffle();
+        DataNormalization normalizer = new NormalizerStandardize();
+        normalizer.fit(data);
+        normalizer.transform(data);
+
+        return data.splitTestAndTrain(parameters.getPercentTestTrain());
+    }
+
+    /**
+     * @param data
+     * @param expected
+     * @param fold_number
+     * @param increase_ratio
+     * @return
+     */
+    public HashMap<MultiLayerNetwork, Double> runKFold(
+            final List<List<Double>> data,
+            final List<Double> expected,
+            final int fold_number,
+            final int increase_ratio) {
+        TrainingDataset dataset = new TrainingDataset(data, expected);
+        List<TrainingDataset> folds = dataset.prepareFolds(fold_number);
+        HashMap<MultiLayerNetwork, Double> map = new HashMap<>();
+        for (int i = 0; i < fold_number; i++) {
+            TrainingDataset testingfold = folds.get(i);
+            TrainingDataset learning_fold = new TrainingDataset();
+            for (int j = 0; j < fold_number; j++) {
+                if (j != i) {
+                    learning_fold.addFoldInDataset(folds, j);
+                }
+            }
+            TrainingDataset dataset_increased = learning_fold.increaseTrueAlert(
+                    increase_ratio);
+            MultiLayerNetwork nn = run(
+                    dataset_increased.getData(),
+                    dataset_increased.getExpected());
+            Double score = modelEvaluation(testingfold, nn);
+
+            map.put(nn, score);
+        }
+        return map;
+    }
+
+    /**
+     * @param filename_data
+     * @param filename_expected
+     * @param fold_number
+     * @param increase_ratio
+     * @return
+     */
+    public HashMap<MultiLayerNetwork, Double> runKFold(
+            final String filename_data,
+            final String filename_expected,
+            final int fold_number,
+            final int increase_ratio) {
+        List<List<Double>> data
+                = Utils.convertJsonToDataForTrainer(filename_data);
+        List<Double> expected
+                = Utils.convertJsonToExpectedForTrainer(filename_expected);
+
+        return runKFold(data, expected, fold_number, increase_ratio);
+    }
+
+    HashMap<MultiLayerNetwork, Double> runKFold(
+            final List<TrainingDataset> prepared_folds,
+            final int increase_ratio) {
+        HashMap<MultiLayerNetwork, Double> map = new HashMap<>();
+        for (int i = 0; i < prepared_folds.size(); i++) {
+            TrainingDataset testingfold = prepared_folds.get(i);
+            TrainingDataset learning_fold = new TrainingDataset();
+            for (int j = 0; j < prepared_folds.size(); j++) {
+                if (j != i) {
+                    learning_fold.addFoldInDataset(prepared_folds, j);
+                }
+            }
+            TrainingDataset dataset_increased = learning_fold.increaseTrueAlert(
+                    increase_ratio);
+            System.out.println("Fold number : " + (i + 1));
+            MultiLayerNetwork nn = run(
+                    dataset_increased.getData(),
+                    dataset_increased.getExpected());
+            Double score = modelEvaluation(testingfold, nn);
+            map.put(nn, score);
+        }
+        return map;
+    }
+
+}
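A minimal end-to-end sketch of the k-fold entry point added above, assuming the JSON files exist (paths illustrative):

    HyperParameters params = new HyperParameters(20, 0.01,
            OptimizationAlgorithm.CONJUGATE_GRADIENT, Activation.TANH);
    NeuralNetwork nn = new NeuralNetwork(params);
    HashMap<MultiLayerNetwork, Double> scores = nn.runKFold(
            "./ressources/webshell_data_new_version.json",
            "./ressources/webshell_expected_new_version.json",
            10,   // fold_number
            10);  // increase_ratio for true alerts
    double average = scores.values().stream()
            .mapToDouble(Double::doubleValue).average().orElse(0.0);
    System.out.println("Average AUC: " + average);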
diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java
index 6eb13d1c6d6cc09088ffd3bf1dfdc0a7499c3de0..87c45ff5245a0a732f92122a9b087ccf8437a575 100644
--- a/src/main/java/be/cylab/java/wowa/training/Trainer.java
+++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java
@@ -104,7 +104,7 @@ public class Trainer {
             final int fold_number,
             final int increase_ration_alert) {
         TrainingDataset dataset = new TrainingDataset(data, expected);
-        List<TrainingDataset> folds = prepareFolds(dataset, fold_number);
+        List<TrainingDataset> folds = dataset.prepareFolds(fold_number);
         HashMap<AbstractSolution, Double> map = new HashMap<>();
         for (int i = 0; i < fold_number; i++) {
             TrainingDataset testing = folds.get(i);
@@ -114,9 +114,7 @@ public class Trainer {
                     learning.addFoldInDataset(folds, j);
                 }
             }
-            TrainingDataset dataset_increased = increaseTrueAlert(
-                    learning.getData(),
-                    learning.getExpected(),
+            TrainingDataset dataset_increased = learning.increaseTrueAlert(
                     increase_ration_alert);
             AbstractSolution sol = run(
                     dataset_increased.getData(),
@@ -132,6 +130,8 @@ public class Trainer {
                     "Fold_" + (i + 1) + ".txt",
                     sol);
             map.put(sol, score);
+            parameters.getLogger().log(Level.WARNING,
+                    "Solution fold " + i + ":" + sol + " - AUC: " + score);
         }
         return map;
 
@@ -158,83 +158,38 @@ public class Trainer {
     }
 
     /**
-     * Method to increase the number of true alert in data.
-     * to increase penalty to do not detect a true alert
-     *
-     * @param data
-     * @param expected
+     * @param prepared_folds
      * @param increase_ratio
+     * @return
      */
-    final TrainingDataset increaseTrueAlert(
-            final List<List<Double>> data,
-            final List<Double> expected,
+    final HashMap<AbstractSolution, Double> runKFold(
+            final List<TrainingDataset> prepared_folds,
             final int increase_ratio) {
-        int data_size = expected.size();
-        for (int i = 0; i < data_size; i++) {
-            if (expected.get(i) == 1) {
-                for (int j = 0; j < increase_ratio - 1; j++) {
-                    expected.add(expected.get(i));
-                    data.add(data.get(i));
+        HashMap<AbstractSolution, Double> map = new HashMap<>();
+        for (int i = 0; i < prepared_folds.size(); i++) {
+            TrainingDataset testing = prepared_folds.get(i);
+            TrainingDataset learning = new TrainingDataset();
+            for (int j = 0; j < prepared_folds.size(); j++) {
+                if (j != i) {
+                    learning.addFoldInDataset(prepared_folds, j);
                 }
             }
-        }
-        return new TrainingDataset(data, expected);
-    }
-
-    /**
-     * Method to separate randomly the base dataset in X folds dataset.
-     *
-     * @param dataset
-     * @param fold_number
-     * @return
-     */
-     final List<TrainingDataset> prepareFolds(
-            final TrainingDataset dataset,
-            final int fold_number) {
-        //List<List<Double>> data = dataset.getData();
-        List<Double> expected = dataset.getExpected();
-        List<TrainingDataset> fold_dataset = new ArrayList<>();
-        //Check if it is rounded !!!!
-        int alert_number
-                = (int) Math.floor(Utils.sumListElements(expected)
-                / fold_number);
-        int no_alert_number = (int) (expected.size()
-                - Utils.sumListElements(expected)) / fold_number;
+            TrainingDataset dataset_increased = learning.increaseTrueAlert(
+                    increase_ratio);
+            AbstractSolution sol = run(
+                    dataset_increased.getData(),
+                    dataset_increased.getExpected());
+            Double score = sol.computeAUC(
+                    testing.getData(),
+                    testing.getExpected());
+            System.out.println("Fold number : " + (i + 1) + "AUC : " + score);
 
-        for (int i = 0; i < fold_number; i++) {
-            TrainingDataset tmp = new TrainingDataset();
-            int alert_counter = 0;
-            int no_alert_counter = 0;
-            while (tmp.getLength() < (alert_number + no_alert_number)) {
-                int index = Utils.randomInteger(0, dataset.getLength() - 1);
-                if (dataset.getExpected().get(index) == 1
-                        && alert_counter < alert_number) {
-                    tmp.addElementInDataset(dataset, index);
-                    dataset.removeElementInDataset(index);
-                    alert_counter++;
-                } else if (dataset.getExpected().get(index) == 0
-                        && no_alert_counter < no_alert_number) {
-                    tmp.addElementInDataset(dataset, index);
-                    dataset.removeElementInDataset(index);
-                    no_alert_counter++;
-                }
-            }
-            fold_dataset.add(tmp);
-        }
-        int fold_counter = 0;
-        while (dataset.getLength() > 0) {
-            int i = Utils.randomInteger(0, dataset.getLength() - 1);
-            fold_dataset.get(fold_counter).addElementInDataset(dataset, i);
-            dataset.removeElementInDataset(i);
-            if (fold_counter == fold_dataset.size() - 1) {
-                fold_counter = 0;
-            } else {
-                fold_counter++;
-            }
+            map.put(sol, score);
         }
-        return fold_dataset;
+        return map;
     }
 
+
     /**
      * Find the best element in the population based on its fitness score.
      *
@@ -388,6 +343,7 @@ public class Trainer {
     /**
      * Method used only for tests !!
      * This method generates random number (tos) by using a seed !
+     *
      * @param solutions
      * @param selected_elements
      * @param count
diff --git a/src/main/java/be/cylab/java/wowa/training/TrainingDataset.java b/src/main/java/be/cylab/java/wowa/training/TrainingDataset.java
index 2c75c85e1acbdf0ed1c9d5f888cbe615971e6d46..6d7ff6bb826666d911886309d534f36f32acc4b5 100644
--- a/src/main/java/be/cylab/java/wowa/training/TrainingDataset.java
+++ b/src/main/java/be/cylab/java/wowa/training/TrainingDataset.java
@@ -117,4 +117,76 @@ class TrainingDataset {
         this.length += data_to_add.get(index).getLength();
         return this;
     }
+
+    /**
+     * Method to separate randomly the base dataset in X folds dataset.
+     *
+     * @param fold_number
+     * @return
+     */
+    final List<TrainingDataset> prepareFolds(
+            final int fold_number) {
+        //List<List<Double>> data = dataset.getData();
+        List<Double> expected = this.getExpected();
+        List<TrainingDataset> fold_dataset = new ArrayList<>();
+        //Check if it is rounded !!!!
+        int alert_number
+                = (int) Math.floor(Utils.sumListElements(expected)
+                / fold_number);
+        int no_alert_number = (int) (expected.size()
+                - Utils.sumListElements(expected)) / fold_number;
+
+        for (int i = 0; i < fold_number; i++) {
+            TrainingDataset tmp = new TrainingDataset();
+            int alert_counter = 0;
+            int no_alert_counter = 0;
+            while (tmp.getLength() < (alert_number + no_alert_number)) {
+                int index = Utils.randomInteger(0, this.getLength() - 1);
+                if (this.getExpected().get(index) == 1
+                        && alert_counter < alert_number) {
+                    tmp.addElementInDataset(this, index);
+                    this.removeElementInDataset(index);
+                    alert_counter++;
+                } else if (this.getExpected().get(index) == 0
+                        && no_alert_counter < no_alert_number) {
+                    tmp.addElementInDataset(this, index);
+                    this.removeElementInDataset(index);
+                    no_alert_counter++;
+                }
+            }
+            fold_dataset.add(tmp);
+        }
+        int fold_counter = 0;
+        while (this.getLength() > 0) {
+            int i = Utils.randomInteger(0, this.getLength() - 1);
+            fold_dataset.get(fold_counter).addElementInDataset(this, i);
+            this.removeElementInDataset(i);
+            if (fold_counter == fold_dataset.size() - 1) {
+                fold_counter = 0;
+            } else {
+                fold_counter++;
+            }
+        }
+        return fold_dataset;
+    }
+
+    /**
+     * Method to increase the number of true alerts in the dataset,
+     * to increase the penalty for failing to detect a true alert.
+     *
+     * @param increase_ratio
+     * @return
+     */
+    final TrainingDataset increaseTrueAlert(
+            final int increase_ratio) {
+        int data_size = expected.size();
+        for (int i = 0; i < data_size; i++) {
+            if (expected.get(i) == 1) {
+                for (int j = 0; j < increase_ratio - 1; j++) {
+                    expected.add(expected.get(i));
+                    data.add(data.get(i));
+                }
+            }
+        }
+        return new TrainingDataset(data, expected);
+    }
 }
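Taken together, the two new methods set up class-balanced cross-validation. A minimal sketch of the intended call sequence, assuming a driver class in the same package (TrainingDataset is package-private); the feature values are random filler:

```java
package be.cylab.java.wowa.training;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical driver class, not part of this patch.
final class FoldSketch {

    private FoldSketch() {
    }

    public static void main(final String[] args) {
        Random rnd = new Random(123);
        List<List<Double>> data = new ArrayList<>();
        List<Double> expected = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            List<Double> row = new ArrayList<>();
            for (int j = 0; j < 5; j++) {
                row.add(rnd.nextDouble());
            }
            data.add(row);
            expected.add(rnd.nextBoolean() ? 1.0 : 0.0);
        }

        // Oversample the alerts, then split into 10 stratified folds.
        TrainingDataset ds = new TrainingDataset(data, expected)
                .increaseTrueAlert(10);
        List<TrainingDataset> folds = ds.prepareFolds(10);

        // Caution: prepareFolds() moves elements out of the receiver,
        // so ds is empty afterwards and the folds hold the only copy.
        for (TrainingDataset fold : folds) {
            System.out.println(fold.getLength());
        }
    }
}
```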
diff --git a/src/main/java/be/cylab/java/wowa/training/Utils.java b/src/main/java/be/cylab/java/wowa/training/Utils.java
index e8174f4be434425d68cfa1251cd77d97b3f8d493..1122429512d9b1f67bda303865fc81b9fed5e466 100644
--- a/src/main/java/be/cylab/java/wowa/training/Utils.java
+++ b/src/main/java/be/cylab/java/wowa/training/Utils.java
@@ -4,6 +4,12 @@ import com.owlike.genson.GenericType;
 import com.owlike.genson.Genson;
 import info.debatty.java.aggregation.WOWA;
 import org.apache.commons.lang3.ArrayUtils;
+import org.deeplearning4j.datasets.iterator.DoublesDataSetIterator;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.primitives.Pair;
 
 import java.io.FileOutputStream;
 import java.io.IOException;
@@ -534,9 +540,6 @@ final class Utils {
                 new FileOutputStream(filename, true),
                 StandardCharsets.UTF_8)) {
 
-            writer.write("___________________________________________\n");
-            writer.write("|TRIGGER VALUE       :" + trigger_value + "\n");
-            writer.write("|_________________________________________|\n");
             writer.write("AVERAGE\n\n");
             writer.write("Average true positive : "
                     + average_true_positive_counter + "\n");
@@ -569,4 +572,63 @@ final class Utils {
             e.printStackTrace();
         }
     }
+
+    static DataSet createDataSet(
+            final List<List<Double>> data,
+            final List<Double> expected) {
+        if (data.size() != expected.size()) {
+            throw new IllegalArgumentException(
+                    "Data and Expected must have the same size");
+        }
+        double[][] data_array = new double[data.size()][data.get(0).size()];
+        double[][] expected_array = new double[expected.size()][2];
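+        // One-hot encode the label: [1, 0] = alert, [0, 1] = no alert.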
+        for (int i = 0; i < data.size(); i++) {
+            if (expected.get(i) == 1.0) {
+                expected_array[i][0] = 1.0;
+                expected_array[i][1] = 0.0;
+            } else {
+                expected_array[i][0] = 0.0;
+                expected_array[i][1] = 1.0;
+            }
+
+            for (int j = 0; j < data.get(i).size(); j++) {
+                data_array[i][j] = data.get(i).get(j);
+            }
+        }
+        INDArray data_ind = Nd4j.create(data_array);
+        INDArray expected_ind = Nd4j.create(expected_array);
+        return new DataSet(data_ind, expected_ind);
+    }
+
+    static DataSetIterator createDataSetIterator(
+            final List<List<Double>> data,
+            final List<Double> expected,
+            final int batch_size
+            ) {
+        if (data.size() != expected.size()) {
+            throw new IllegalArgumentException(
+                    "Data and Expected must have the same size");
+        }
+        List<Pair<double[], double[]>> ds = new ArrayList<>();
+        for (int i = 0; i < data.size(); i++) {
+            double[] d = convertListDoubleToArrayDouble(data.get(i));
+            // One-hot encode the label: [1, 0] = alert, [0, 1] = no alert.
+            double[] e = expected.get(i) == 1.0
+                    ? new double[]{1.0, 0.0}
+                    : new double[]{0.0, 1.0};
+            ds.add(new Pair<>(d, e));
+        }
+        return new DoublesDataSetIterator(ds, batch_size);
+    }
 }
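A short sketch of how the two new helpers are meant to be called, again assuming a hypothetical driver in the same package (Utils is package-private) and random filler data:

```java
package be.cylab.java.wowa.training;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

// Hypothetical driver class, not part of this patch.
final class DataSetSketch {

    private DataSetSketch() {
    }

    public static void main(final String[] args) {
        Random rnd = new Random(42);
        List<List<Double>> data = new ArrayList<>();
        List<Double> expected = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            List<Double> row = new ArrayList<>();
            for (int j = 0; j < 5; j++) {
                row.add(rnd.nextDouble());
            }
            data.add(row);
            expected.add(rnd.nextBoolean() ? 1.0 : 0.0);
        }

        // Whole dataset in one DataSet (features + one-hot labels).
        DataSet full = Utils.createDataSet(data, expected);
        System.out.println(full.numExamples());

        // Same data as an iterator over mini-batches of 4 examples.
        DataSetIterator it = Utils.createDataSetIterator(data, expected, 4);
        while (it.hasNext()) {
            DataSet batch = it.next();
            System.out.println(batch.numExamples());
        }
    }
}
```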
diff --git a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
index fd5d815f5260415d4c216c6a49134e66f358bf64..41bcd04e790770c42e29660142594b4e4f709fdf 100644
--- a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
+++ b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
@@ -22,66 +22,6 @@ class TrainerTest {
     void testRun() {
     }
 
-    @Test
-    void testIncreaseTrueAlert() {
-        List<List<Double>> data = generateData(100, 5);
-        int original_data_set_size = data.size();
-        List<Double> expected = generateExpectedBinaryClassification(100);
-        int increase_ratio = 10;
-        int number_of_true_alert = (int)(double)Utils.sumListElements(expected);
-        int number_of_no_alert = original_data_set_size - number_of_true_alert;
-        TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio);
-
-        //Check if the length of the dataset us correct
-        //increase_ration * number_of_true_alert + expected.size()
-        assertEquals(number_of_true_alert * increase_ratio + number_of_no_alert, ds.getLength());
-
-        //Check the number of true_alert in the dataset
-        //increase_ration * number_of_true_alert + number_of_true_alert
-        assertEquals(number_of_true_alert * increase_ratio, (double)Utils.sumListElements(ds.getExpected()));
-
-        //Check if each rue alert elements are present (increase_ratio time) in he dataset
-        //Here, we check if each true alert element is present 10 more times than the original one
-        for (int i = 0; i < expected.size(); i++) {
-            if (ds.getExpected().get(i) == 1.0) {
-                int cnt = 0;
-                for (int j = 0; j < ds.getLength(); j++) {
-                    if (ds.getExpected().get(i) == ds.getExpected().get(j)) {
-                        cnt++;
-                    }
-                }
-                assertEquals(increase_ratio, cnt);
-            }
-        }
-
-    }
-
-    @Test
-    void testPrepareFolds() {
-        int number_of_elements = 100;
-        List<List<Double>> data = generateData(number_of_elements,5);
-        List<Double> expected = generateExpectedBinaryClassification(number_of_elements);
-        int increase_ratio = 3;
-        int fold_number = 10;
-        int number_of_alert = (int)(double)Utils.sumListElements(expected);
-        int number_of_no_alert = number_of_elements - number_of_alert;
-        TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio);
-        List<TrainingDataset> folds = trainer.prepareFolds(ds, fold_number);
-        assertEquals(fold_number, folds.size());
-        for (int i = 0; i < folds.size(); i++) {
-            assertTrue(folds.get(i).getLength()
-                    == (number_of_alert * increase_ratio + number_of_no_alert)
-                    / fold_number || folds.get(i).getLength()
-                    == 1 + (number_of_alert * increase_ratio + number_of_no_alert)
-                    / fold_number);
-            assertTrue(Utils.sumListElements(folds.get(i).getExpected())
-                    == ((number_of_alert * increase_ratio) / fold_number) ||
-                    Utils.sumListElements(folds.get(i).getExpected())
-                            == 1 + ((number_of_alert * increase_ratio) / fold_number)
-                    || Utils.sumListElements(folds.get(i).getExpected())
-                    == 2 + ((number_of_alert * increase_ratio) / fold_number));
-        }
-    }
 
     @Test
     void testFindBestSolution() {
diff --git a/src/test/java/be/cylab/java/wowa/training/TrainingDatasetTest.java b/src/test/java/be/cylab/java/wowa/training/TrainingDatasetTest.java
index a2eda8a8608d96c4f82146a3a909da0c961ce66f..94a5e59b57e4fc2a725de5eb412e7b424baedbf3 100644
--- a/src/test/java/be/cylab/java/wowa/training/TrainingDatasetTest.java
+++ b/src/test/java/be/cylab/java/wowa/training/TrainingDatasetTest.java
@@ -6,6 +6,7 @@ import org.junit.jupiter.api.Test;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Random;
 
 import static org.junit.jupiter.api.Assertions.*;
 
@@ -108,4 +109,100 @@ class TrainingDatasetTest {
             }
         }
     }
+
+    @Test
+    void testIncreaseTrueAlert() {
+        List<List<Double>> data = generateData(100, 5);
+        int original_data_set_size = data.size();
+        List<Double> expected = generateExpectedBinaryClassification(100);
+        int increase_ratio = 10;
+        int number_of_true_alert = (int)(double)Utils.sumListElements(expected);
+        int number_of_no_alert = original_data_set_size - number_of_true_alert;
+        TrainingDataset ds = new TrainingDataset(data, expected).increaseTrueAlert(increase_ratio);
+
+        //Check that the length of the dataset is correct:
+        //increase_ratio * number_of_true_alert + number_of_no_alert
+        assertEquals(number_of_true_alert * increase_ratio + number_of_no_alert, ds.getLength());
+
+        //Check the number of true alerts in the dataset:
+        //increase_ratio * number_of_true_alert
+        assertEquals(number_of_true_alert * increase_ratio, (double)Utils.sumListElements(ds.getExpected()));
+
+        //Check that each true alert element is present increase_ratio
+        //times in the dataset (here, 10 times in total).
+        for (int i = 0; i < expected.size(); i++) {
+            if (ds.getExpected().get(i) == 1.0) {
+                int cnt = 0;
+                for (int j = 0; j < ds.getLength(); j++) {
+                    // Compare the feature rows: duplicated alerts share
+                    // the same vector, distinct rows almost surely differ.
+                    if (ds.getData().get(i).equals(ds.getData().get(j))) {
+                        cnt++;
+                    }
+                }
+                assertEquals(increase_ratio, cnt);
+            }
+        }
+
+    }
+
+    @Test
+    void testPrepareFolds() {
+        int number_of_elements = 100;
+        List<List<Double>> data = generateData(number_of_elements,5);
+        List<Double> expected = generateExpectedBinaryClassification(number_of_elements);
+        int increase_ratio = 3;
+        int fold_number = 10;
+        int number_of_alert = (int)(double)Utils.sumListElements(expected);
+        int number_of_no_alert = number_of_elements - number_of_alert;
+        TrainingDataset ds = new TrainingDataset(data, expected).increaseTrueAlert(increase_ratio);
+        List<TrainingDataset> folds = ds.prepareFolds(fold_number);
+        assertEquals(fold_number, folds.size());
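+        // Each fold should hold total / fold_number elements, give or
+        // take one (e.g. ~50 alerts tripled plus ~50 no-alerts is ~200
+        // elements, so folds of about 20); the remainder of the integer
+        // division is spread round-robin over the folds.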
+        for (int i = 0; i < folds.size(); i++) {
+            assertTrue(folds.get(i).getLength()
+                    == (number_of_alert * increase_ratio + number_of_no_alert)
+                    / fold_number || folds.get(i).getLength()
+                    == 1 + (number_of_alert * increase_ratio + number_of_no_alert)
+                    / fold_number);
+            assertTrue(Utils.sumListElements(folds.get(i).getExpected())
+                    == ((number_of_alert * increase_ratio) / fold_number) ||
+                    Utils.sumListElements(folds.get(i).getExpected())
+                            == 1 + ((number_of_alert * increase_ratio) / fold_number)
+                    || Utils.sumListElements(folds.get(i).getExpected())
+                    == 2 + ((number_of_alert * increase_ratio) / fold_number));
+        }
+    }
+
+    static List<List<Double>> generateData(final int size, final int weight_number) {
+        Random rnd = new Random(5489);
+        List<List<Double>> data = new ArrayList<>();
+        for (int i = 0; i < size; i++) {
+            List<Double> vector = new ArrayList<>();
+            for (int j = 0; j < weight_number; j++) {
+                vector.add(rnd.nextDouble());
+            }
+            data.add(vector);
+        }
+        return data;
+    }
+
+    static List<Double> generateExpected(final int size) {
+        Random rnd = new Random(5768);
+        List<Double> expected = new ArrayList<>();
+        for (int i = 0; i < size; i++) {
+            expected.add(rnd.nextDouble());
+        }
+        return expected;
+    }
+
+    static List<Double> generateExpectedBinaryClassification(final int size) {
+        Random rnd = new Random(5768);
+        List<Double> expected = new ArrayList<>();
+        for (int i = 0; i < size; i++) {
+            if (rnd.nextDouble() <= 0.5) {
+                expected.add(0.0);
+            } else {
+                expected.add(1.0);
+            }
+        }
+        return expected;
+    }
 }
\ No newline at end of file