From bb2cdc963686e0b68dc6f1160cc00920adad3541 Mon Sep 17 00:00:00 2001
From: Alex <croix.alexandre@gmail.com>
Date: Fri, 13 Sep 2019 11:17:35 +0200
Subject: [PATCH] Modification MainDL4J for parametric study

---
 java-wowa-training.iml                        |  4 +-
 .../be/cylab/java/wowa/training/MainDL4J.java | 70 +++++++++++++------
 2 files changed, 50 insertions(+), 24 deletions(-)

diff --git a/java-wowa-training.iml b/java-wowa-training.iml
index 08c09f6..e65012b 100644
--- a/java-wowa-training.iml
+++ b/java-wowa-training.iml
@@ -42,13 +42,15 @@
     <orderEntry type="library" name="Maven: info.debatty:java-aggregation:0.5" level="project" />
     <orderEntry type="library" name="Maven: com.owlike:genson:1.6" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-lang3:3.8.1" level="project" />
-    <orderEntry type="library" name="Maven: be.cylab:java-roc:0.0.3" level="project" />
+    <orderEntry type="library" name="Maven: be.cylab:java-roc:0.0.4" level="project" />
     <orderEntry type="library" name="Maven: com.opencsv:opencsv:4.5" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-text:1.3" level="project" />
     <orderEntry type="library" name="Maven: commons-beanutils:commons-beanutils:1.9.3" level="project" />
     <orderEntry type="library" name="Maven: commons-logging:commons-logging:1.2" level="project" />
     <orderEntry type="library" name="Maven: commons-collections:commons-collections:3.2.2" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-collections4:4.2" level="project" />
+    <orderEntry type="library" name="Maven: org.knowm.xchart:xchart:3.5.4" level="project" />
+    <orderEntry type="library" name="Maven: de.erichseifert.vectorgraphics2d:VectorGraphics2D:0.13" level="project" />
     <orderEntry type="library" name="Maven: org.datavec:datavec-api:0.9.1" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-compress:1.8.1" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-math3:3.3" level="project" />
diff --git a/src/main/java/be/cylab/java/wowa/training/MainDL4J.java b/src/main/java/be/cylab/java/wowa/training/MainDL4J.java
index cf2f3ec..20f2bd6 100644
--- a/src/main/java/be/cylab/java/wowa/training/MainDL4J.java
+++ b/src/main/java/be/cylab/java/wowa/training/MainDL4J.java
@@ -4,6 +4,12 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.nd4j.linalg.activations.Activation;
 
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+
 /**
  * Class for learn in neuronal network.
  */
@@ -21,20 +27,24 @@ public final class MainDL4J {
      * @param args
      */
     public static void main(final String[] args) {
-
-        //String file_name = args[0];
-        double learning_rate = Double.parseDouble(args[1]);
+        int begin_neuron_number = Integer.parseInt(args[0]);
+        int neuron_number_end = Integer.parseInt(args[1]);
+        String data_file = args[5];
+        String expected_file = args[6];
+        double learning_rate = Double.parseDouble(args[2]);
         OptimizationAlgorithm optimization_algorithm;
         Activation activation_function;
+        int fold_number = 10;
+        int increase_ratio = 10;
 
-        if (args[2].matches("CONJUGATE_GRADIENT")) {
+        if (args[3].matches("CONJUGATE_GRADIENT")) {
             optimization_algorithm = OptimizationAlgorithm.CONJUGATE_GRADIENT;
-        } else if (args[2].matches("LBFGS")) {
+        } else if (args[3].matches("LBFGS")) {
             optimization_algorithm = OptimizationAlgorithm.LBFGS;
-        } else if (args[2].matches("LINE_GRADIENT_DESCENT")) {
+        } else if (args[3].matches("LINE_GRADIENT_DESCENT")) {
             optimization_algorithm
                     = OptimizationAlgorithm.LINE_GRADIENT_DESCENT;
-        } else if (args[2].matches("STOCHASTIC_GRADIENT_DESCENT")) {
+        } else if (args[3].matches("STOCHASTIC_GRADIENT_DESCENT")) {
             optimization_algorithm
                     = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
         } else {
@@ -42,32 +52,46 @@ public final class MainDL4J {
                     "Not correct optimization algorithm");
         }
 
-        if (args[3].matches("RELU")) {
+        if (args[4].matches("RELU")) {
             activation_function = Activation.RELU;
-        } else if (args[3].matches("SIGMOID")) {
+        } else if (args[4].matches("SIGMOID")) {
             activation_function = Activation.SIGMOID;
-        } else if (args[3].matches("TANH")) {
+        } else if (args[4].matches("TANH")) {
             activation_function = Activation.TANH;
         } else {
             throw new IllegalArgumentException(
                     "Not correct activation function");
         }
 
-        HyperParameters parameters = new HyperParameters(20, learning_rate,
-                optimization_algorithm, activation_function);
-        NeuralNetwork nn = new NeuralNetwork(parameters);
-        MultiLayerNetwork mnn = nn.run("./ressources/webshell_data.json",
-                "./ressources/webshell_expected.json");
-        System.out.println("Begin K-fold cross-validation");
+        for (int neuron_number = begin_neuron_number;
+             neuron_number < neuron_number_end; neuron_number++) {
+            double average_auc = 0;
+            HyperParameters parameters
+                    = new HyperParameters(neuron_number, learning_rate,
+                    optimization_algorithm, activation_function);
+            NeuralNetwork nn = new NeuralNetwork(parameters);
 
-        //HashMap<MultiLayerNetwork, Double> map = nn.runKFold(
-        //        "./ressources/webshell_data.json",
-        //        "./ressources/webshell_expected.json",
-        //        10,
-        //        10);
+            HashMap<MultiLayerNetwork, Double> map = nn.runKFold(
+                    data_file,
+                    expected_file,
+                    fold_number,
+                    increase_ratio);
+            for (Double d : map.values()) {
+                average_auc = average_auc + d;
+            }
+            try (OutputStreamWriter writer = new OutputStreamWriter(
+                    new FileOutputStream("Synthesis_average_AUC.txt", true),
+                    StandardCharsets.UTF_8)) {
+                writer.write("Neuron number : " + neuron_number
+                        + "Learning rate : " + learning_rate
+                        + "Algorithm : " + args[3]
+                        + "Activation function : " + args[4]
+                        + "Average AUC = " + (average_auc / fold_number));
+            } catch (IOException e) {
+                e.printStackTrace();
+            }
+        }
 
-        NeuralNetwork nnn = new NeuralNetwork(parameters);
-        nnn.runCSV("webshell_data.csv");
 
     }
 }
-- 
GitLab