Skip to content
Snippets Groups Projects
Commit bb2cdc96 authored by a.croix's avatar a.croix
Browse files

Modification MainDL4J for parametric study

parent f20504e8
No related branches found
No related tags found
1 merge request!4Neural network
Pipeline #2314 failed
......@@ -42,13 +42,15 @@
<orderEntry type="library" name="Maven: info.debatty:java-aggregation:0.5" level="project" />
<orderEntry type="library" name="Maven: com.owlike:genson:1.6" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-lang3:3.8.1" level="project" />
<orderEntry type="library" name="Maven: be.cylab:java-roc:0.0.4" level="project" />
<orderEntry type="library" name="Maven: com.opencsv:opencsv:4.5" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-text:1.3" level="project" />
<orderEntry type="library" name="Maven: commons-beanutils:commons-beanutils:1.9.3" level="project" />
<orderEntry type="library" name="Maven: commons-logging:commons-logging:1.2" level="project" />
<orderEntry type="library" name="Maven: commons-collections:commons-collections:3.2.2" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-collections4:4.2" level="project" />
<orderEntry type="library" name="Maven: org.knowm.xchart:xchart:3.5.4" level="project" />
<orderEntry type="library" name="Maven: de.erichseifert.vectorgraphics2d:VectorGraphics2D:0.13" level="project" />
<orderEntry type="library" name="Maven: org.datavec:datavec-api:0.9.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-compress:1.8.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-math3:3.3" level="project" />
......
......@@ -4,6 +4,12 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
/**
* Class for training a neural network.
*/
......@@ -21,20 +27,24 @@ public final class MainDL4J {
* @param args
*/
public static void main(final String[] args) {
//String file_name = args[0];
double learning_rate = Double.parseDouble(args[1]);
int begin_neuron_number = Integer.parseInt(args[0]);
int neuron_number_end = Integer.parseInt(args[1]);
String data_file = args[5];
String expected_file = args[6];
double learning_rate = Double.parseDouble(args[2]);
OptimizationAlgorithm optimization_algorithm;
Activation activation_function;
int fold_number = 10;
int increase_ratio = 10;
if (args[2].matches("CONJUGATE_GRADIENT")) {
if (args[3].matches("CONJUGATE_GRADIENT")) {
optimization_algorithm = OptimizationAlgorithm.CONJUGATE_GRADIENT;
} else if (args[2].matches("LBFGS")) {
} else if (args[3].matches("LBFGS")) {
optimization_algorithm = OptimizationAlgorithm.LBFGS;
} else if (args[2].matches("LINE_GRADIENT_DESCENT")) {
} else if (args[3].matches("LINE_GRADIENT_DESCENT")) {
optimization_algorithm
= OptimizationAlgorithm.LINE_GRADIENT_DESCENT;
} else if (args[2].matches("STOCHASTIC_GRADIENT_DESCENT")) {
} else if (args[3].matches("STOCHASTIC_GRADIENT_DESCENT")) {
optimization_algorithm
= OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
} else {
......@@ -42,32 +52,46 @@ public final class MainDL4J {
"Not correct optimization algorithm");
}
if (args[3].matches("RELU")) {
if (args[4].matches("RELU")) {
activation_function = Activation.RELU;
} else if (args[3].matches("SIGMOID")) {
} else if (args[4].matches("SIGMOID")) {
activation_function = Activation.SIGMOID;
} else if (args[3].matches("TANH")) {
} else if (args[4].matches("TANH")) {
activation_function = Activation.TANH;
} else {
throw new IllegalArgumentException(
"Not correct activation function");
}
HyperParameters parameters = new HyperParameters(20, learning_rate,
optimization_algorithm, activation_function);
NeuralNetwork nn = new NeuralNetwork(parameters);
MultiLayerNetwork mnn = nn.run("./ressources/webshell_data.json",
"./ressources/webshell_expected.json");
System.out.println("Begin K-fold cross-validation");
for (int neuron_number = begin_neuron_number;
neuron_number < neuron_number_end; neuron_number++) {
double average_auc = 0;
HyperParameters parameters
= new HyperParameters(neuron_number, learning_rate,
optimization_algorithm, activation_function);
NeuralNetwork nn = new NeuralNetwork(parameters);
//HashMap<MultiLayerNetwork, Double> map = nn.runKFold(
// "./ressources/webshell_data.json",
// "./ressources/webshell_expected.json",
// 10,
// 10);
HashMap<MultiLayerNetwork, Double> map = nn.runKFold(
data_file,
expected_file,
fold_number,
increase_ratio);
for (Double d : map.values()) {
average_auc = average_auc + d;
}
try (OutputStreamWriter writer = new OutputStreamWriter(
new FileOutputStream("Synthesis_average_AUC.txt", true),
StandardCharsets.UTF_8)) {
writer.write("Neuron number : " + neuron_number
+ "Learning rate : " + learning_rate
+ "Algorithm : " + args[3]
+ "Activation function : " + args[4]
+ "Average AUC = " + (average_auc / fold_number));
} catch (IOException e) {
e.printStackTrace();
}
}
NeuralNetwork nnn = new NeuralNetwork(parameters);
nnn.runCSV("webshell_data.csv");
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment