From bb2cdc963686e0b68dc6f1160cc00920adad3541 Mon Sep 17 00:00:00 2001
From: Alex <croix.alexandre@gmail.com>
Date: Fri, 13 Sep 2019 11:17:35 +0200
Subject: [PATCH] Modification MainDL4J for parametric study

---
 java-wowa-training.iml                        |  4 +-
 .../be/cylab/java/wowa/training/MainDL4J.java | 70 +++++++++++++------
 2 files changed, 50 insertions(+), 24 deletions(-)

diff --git a/java-wowa-training.iml b/java-wowa-training.iml
index 08c09f6..e65012b 100644
--- a/java-wowa-training.iml
+++ b/java-wowa-training.iml
@@ -42,13 +42,15 @@
     <orderEntry type="library" name="Maven: info.debatty:java-aggregation:0.5" level="project" />
     <orderEntry type="library" name="Maven: com.owlike:genson:1.6" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-lang3:3.8.1" level="project" />
-    <orderEntry type="library" name="Maven: be.cylab:java-roc:0.0.3" level="project" />
+    <orderEntry type="library" name="Maven: be.cylab:java-roc:0.0.4" level="project" />
     <orderEntry type="library" name="Maven: com.opencsv:opencsv:4.5" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-text:1.3" level="project" />
     <orderEntry type="library" name="Maven: commons-beanutils:commons-beanutils:1.9.3" level="project" />
     <orderEntry type="library" name="Maven: commons-logging:commons-logging:1.2" level="project" />
     <orderEntry type="library" name="Maven: commons-collections:commons-collections:3.2.2" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-collections4:4.2" level="project" />
+    <orderEntry type="library" name="Maven: org.knowm.xchart:xchart:3.5.4" level="project" />
+    <orderEntry type="library" name="Maven: de.erichseifert.vectorgraphics2d:VectorGraphics2D:0.13" level="project" />
     <orderEntry type="library" name="Maven: org.datavec:datavec-api:0.9.1" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-compress:1.8.1" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-math3:3.3" level="project" />
diff --git a/src/main/java/be/cylab/java/wowa/training/MainDL4J.java b/src/main/java/be/cylab/java/wowa/training/MainDL4J.java
index cf2f3ec..20f2bd6 100644
--- a/src/main/java/be/cylab/java/wowa/training/MainDL4J.java
+++ b/src/main/java/be/cylab/java/wowa/training/MainDL4J.java
@@ -4,6 +4,12 @@
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.nd4j.linalg.activations.Activation;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+
 /**
  * Class for learn in neuronal network.
  */
@@ -21,20 +27,24 @@ public final class MainDL4J {
      * @param args
      */
     public static void main(final String[] args) {
-
-        //String file_name = args[0];
-        double learning_rate = Double.parseDouble(args[1]);
+        int begin_neuron_number = Integer.parseInt(args[0]);
+        int neuron_number_end = Integer.parseInt(args[1]);
+        String data_file = args[5];
+        String expected_file = args[6];
+        double learning_rate = Double.parseDouble(args[2]);
         OptimizationAlgorithm optimization_algorithm;
         Activation activation_function;
+        int fold_number = 10;
+        int increase_ratio = 10;
 
-        if (args[2].matches("CONJUGATE_GRADIENT")) {
+        if (args[3].matches("CONJUGATE_GRADIENT")) {
             optimization_algorithm = OptimizationAlgorithm.CONJUGATE_GRADIENT;
-        } else if (args[2].matches("LBFGS")) {
+        } else if (args[3].matches("LBFGS")) {
             optimization_algorithm = OptimizationAlgorithm.LBFGS;
-        } else if (args[2].matches("LINE_GRADIENT_DESCENT")) {
+        } else if (args[3].matches("LINE_GRADIENT_DESCENT")) {
             optimization_algorithm
                     = OptimizationAlgorithm.LINE_GRADIENT_DESCENT;
-        } else if (args[2].matches("STOCHASTIC_GRADIENT_DESCENT")) {
+        } else if (args[3].matches("STOCHASTIC_GRADIENT_DESCENT")) {
             optimization_algorithm
                     = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
         } else {
@@ -42,32 +52,46 @@ public final class MainDL4J {
                     "Not correct optimization algorithm");
         }
 
-        if (args[3].matches("RELU")) {
+        if (args[4].matches("RELU")) {
             activation_function = Activation.RELU;
-        } else if (args[3].matches("SIGMOID")) {
+        } else if (args[4].matches("SIGMOID")) {
             activation_function = Activation.SIGMOID;
-        } else if (args[3].matches("TANH")) {
+        } else if (args[4].matches("TANH")) {
             activation_function = Activation.TANH;
         } else {
             throw new IllegalArgumentException(
                     "Not correct activation function");
         }
 
-        HyperParameters parameters = new HyperParameters(20, learning_rate,
-                optimization_algorithm, activation_function);
-        NeuralNetwork nn = new NeuralNetwork(parameters);
-        MultiLayerNetwork mnn = nn.run("./ressources/webshell_data.json",
-                "./ressources/webshell_expected.json");
-        System.out.println("Begin K-fold cross-validation");
+        for (int neuron_number = begin_neuron_number;
+                neuron_number < neuron_number_end; neuron_number++) {
+            double average_auc = 0;
+            HyperParameters parameters
+                    = new HyperParameters(neuron_number, learning_rate,
+                    optimization_algorithm, activation_function);
+            NeuralNetwork nn = new NeuralNetwork(parameters);
 
-        //HashMap<MultiLayerNetwork, Double> map = nn.runKFold(
-        //        "./ressources/webshell_data.json",
-        //        "./ressources/webshell_expected.json",
-        //        10,
-        //        10);
+            HashMap<MultiLayerNetwork, Double> map = nn.runKFold(
+                    data_file,
+                    expected_file,
+                    fold_number,
+                    increase_ratio);
+            for (Double d : map.values()) {
+                average_auc = average_auc + d;
+            }
+            try (OutputStreamWriter writer = new OutputStreamWriter(
+                    new FileOutputStream("Synthesis_average_AUC.txt", true),
+                    StandardCharsets.UTF_8)) {
+                writer.write("Neuron number : " + neuron_number
+                        + "Learning rate : " + learning_rate
+                        + "Algorithm : " + args[3]
+                        + "Activation function : " + args[4]
+                        + "Average AUC = " + (average_auc / fold_number));
+            } catch (IOException e) {
+                e.printStackTrace();
+            }
+        }
 
-        NeuralNetwork nnn = new NeuralNetwork(parameters);
-        nnn.runCSV("webshell_data.csv");
 
     }
 }
--
GitLab
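
Note on usage: after this patch, MainDL4J.main reads seven positional arguments (neuron-count range to sweep, learning rate, optimization algorithm, activation function, data file, expected file) and appends one summary line per neuron count to Synthesis_average_AUC.txt. The sketch below only illustrates that argument order; the driver class name, the numeric values and the file paths are placeholders (the JSON paths are borrowed from the code removed above), and the class is assumed to sit in the same package as MainDL4J.

    // Hypothetical driver illustrating the argument order expected by the
    // patched MainDL4J.main; values and paths are placeholders, not part of
    // the patch.
    public final class MainDL4JSweepExample {

        private MainDL4JSweepExample() {
        }

        public static void main(final String[] args) {
            MainDL4J.main(new String[]{
                    "5",                                   // args[0]: first neuron count to try
                    "50",                                  // args[1]: neuron count upper bound (exclusive)
                    "0.01",                                // args[2]: learning rate
                    "STOCHASTIC_GRADIENT_DESCENT",         // args[3]: optimization algorithm
                    "RELU",                                // args[4]: activation function
                    "./ressources/webshell_data.json",     // args[5]: data file
                    "./ressources/webshell_expected.json"  // args[6]: expected (labels) file
            });
        }
    }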