Skip to content
Snippets Groups Projects
Commit f11398f3 authored by a.croix's avatar a.croix
Browse files

Add run methods to use a Neural Network with the same structure as...

Add run methods to use a Neural Network with the same structure as wowa-training. The goal is to use the same dataset in each learning to compare the performance
parent 6d9425b9
No related branches found
No related tags found
1 merge request!4Neural network
Pipeline #2234 failed
...@@ -32,6 +32,97 @@ ...@@ -32,6 +32,97 @@
<SOURCES /> <SOURCES />
</library> </library>
</orderEntry> </orderEntry>
<orderEntry type="library" name="Maven: org.datavec:datavec-api:0.9.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-compress:1.8.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-math3:3.3" level="project" />
<orderEntry type="library" name="Maven: commons-io:commons-io:2.4" level="project" />
<orderEntry type="library" name="Maven: org.slf4j:slf4j-api:1.7.7" level="project" />
<orderEntry type="library" name="Maven: joda-time:joda-time:2.9.2" level="project" />
<orderEntry type="library" name="Maven: org.nd4j:jackson:0.9.1" level="project" />
<orderEntry type="library" name="Maven: org.yaml:snakeyaml:1.12" level="project" />
<orderEntry type="library" name="Maven: org.codehaus.woodstox:stax2-api:3.1.4" level="project" />
<orderEntry type="library" name="Maven: org.projectlombok:lombok:1.16.10" level="project" />
<orderEntry type="library" name="Maven: org.freemarker:freemarker:2.3.23" level="project" />
<orderEntry type="library" name="Maven: org.reflections:reflections:0.9.10" level="project" />
<orderEntry type="library" name="Maven: org.javassist:javassist:3.19.0-GA" level="project" />
<orderEntry type="library" name="Maven: com.google.code.findbugs:annotations:2.0.1" level="project" />
<orderEntry type="library" name="Maven: org.nd4j:nd4j-common:0.9.1" level="project" />
<orderEntry type="library" name="Maven: com.github.stephenc.findbugs:findbugs-annotations:1.3.9-1" level="project" />
<orderEntry type="library" name="Maven: com.clearspring.analytics:stream:2.7.0" level="project" />
<orderEntry type="library" name="Maven: it.unimi.dsi:fastutil:6.5.7" level="project" />
<orderEntry type="library" name="Maven: net.sf.opencsv:opencsv:2.3" level="project" />
<orderEntry type="library" name="Maven: org.nd4j:nd4j-api:0.9.1" level="project" />
<orderEntry type="library" name="Maven: org.nd4j:nd4j-buffer:0.9.1" level="project" />
<orderEntry type="library" name="Maven: org.nd4j:nd4j-context:0.9.1" level="project" />
<orderEntry type="library" name="Maven: net.ericaro:neoitertools:1.0.0" level="project" />
<orderEntry type="library" name="Maven: junit:junit:4.8.2" level="project" />
<orderEntry type="library" name="Maven: org.nd4j:nd4j-native:0.9.1" level="project" />
<orderEntry type="library" name="Maven: org.nd4j:nd4j-native:linux-x86_64:0.9.1" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco:javacpp:1.3.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:openblas:0.2.19-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:openblas:linux-x86_64:0.2.19-1.3" level="project" />
<orderEntry type="library" name="Maven: org.nd4j:nd4j-native-api:0.9.1" level="project" />
<orderEntry type="library" name="Maven: org.deeplearning4j:deeplearning4j-nn:0.9.1" level="project" />
<orderEntry type="library" name="Maven: org.nd4j:nd4j-jackson:0.9.1" level="project" />
<orderEntry type="library" name="Maven: org.nd4j:nd4j-base64:0.9.1" level="project" />
<orderEntry type="library" name="Maven: commons-net:commons-net:3.1" level="project" />
<orderEntry type="library" name="Maven: org.deeplearning4j:deeplearning4j-core:0.9.1" level="project" />
<orderEntry type="library" name="Maven: org.deeplearning4j:nearestneighbor-core:0.9.1" level="project" />
<orderEntry type="library" name="Maven: org.deeplearning4j:deeplearning4j-modelimport:0.9.1" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5-platform:1.10.0-patch1-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:1.10.0-patch1-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:linux-x86:1.10.0-patch1-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:linux-x86_64:1.10.0-patch1-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:linux-ppc64le:1.10.0-patch1-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:macosx-x86_64:1.10.0-patch1-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:windows-x86:1.10.0-patch1-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:windows-x86_64:1.10.0-patch1-1.3" level="project" />
<orderEntry type="library" name="Maven: com.google.guava:guava:20.0" level="project" />
<orderEntry type="library" name="Maven: org.datavec:datavec-data-image:0.9.1" level="project" />
<orderEntry type="library" name="Maven: com.github.jai-imageio:jai-imageio-core:1.3.0" level="project" />
<orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-jpeg:3.1.1" level="project" />
<orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-core:3.1.1" level="project" />
<orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-metadata:3.1.1" level="project" />
<orderEntry type="library" name="Maven: com.twelvemonkeys.common:common-lang:3.1.1" level="project" />
<orderEntry type="library" name="Maven: com.twelvemonkeys.common:common-io:3.1.1" level="project" />
<orderEntry type="library" name="Maven: com.twelvemonkeys.common:common-image:3.1.1" level="project" />
<orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-tiff:3.1.1" level="project" />
<orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-psd:3.1.1" level="project" />
<orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-bmp:3.1.1" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco:javacv:1.3.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:3.2.0-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:ffmpeg:3.2.1-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:flycapture:2.9.3.43-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:libdc1394:2.2.4-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:libfreenect:0.5.3-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:libfreenect2:0.2.0-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:librealsense:1.9.6-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:videoinput:0.200-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:artoolkitplus:2.3.1-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:flandmark:1.07-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv-platform:3.2.0-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:android-arm:3.2.0-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:android-x86:3.2.0-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:linux-x86:3.2.0-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:linux-x86_64:3.2.0-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:linux-armhf:3.2.0-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:linux-ppc64le:3.2.0-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:macosx-x86_64:3.2.0-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:windows-x86:3.2.0-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:windows-x86_64:3.2.0-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica-platform:1.73-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:1.73-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:android-arm:1.73-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:android-x86:1.73-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:linux-x86:1.73-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:linux-x86_64:1.73-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:linux-armhf:1.73-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:linux-ppc64le:1.73-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:macosx-x86_64:1.73-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:windows-x86:1.73-1.3" level="project" />
<orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:windows-x86_64:1.73-1.3" level="project" />
<orderEntry type="library" name="Maven: org.deeplearning4j:deeplearning4j-ui-components:0.9.1" level="project" />
<orderEntry type="library" name="Maven: commons-codec:commons-codec:1.10" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-api:5.3.1" level="project" /> <orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-api:5.3.1" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.apiguardian:apiguardian-api:1.0.0" level="project" /> <orderEntry type="library" scope="TEST" name="Maven: org.apiguardian:apiguardian-api:1.0.0" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.opentest4j:opentest4j:1.1.1" level="project" /> <orderEntry type="library" scope="TEST" name="Maven: org.opentest4j:opentest4j:1.1.1" level="project" />
......
...@@ -43,7 +43,7 @@ public final class MainDL4J { ...@@ -43,7 +43,7 @@ public final class MainDL4J {
if (args[3].matches("RELU")) { if (args[3].matches("RELU")) {
activation_function = Activation.RELU; activation_function = Activation.RELU;
} else if (args[3].matches("SIGMOID")) { } else if (args[3].matches("SIGMOID")) {
activation_function = Activation.SIGMOID; activation_function = Activation.SIGMOID;
} else if (args[3].matches("TANH")) { } else if (args[3].matches("TANH")) {
activation_function = Activation.TANH; activation_function = Activation.TANH;
...@@ -52,16 +52,14 @@ public final class MainDL4J { ...@@ -52,16 +52,14 @@ public final class MainDL4J {
"Not correct activation function"); "Not correct activation function");
} }
LearningNeuralNetwork lnn NeuralNetwork nn = new NeuralNetwork(20, learning_rate,
= new LearningNeuralNetwork(file_name, optimization_algorithm, activation_function);
learning_rate, nn.run("./ressources/webshell_data.json",
optimization_algorithm, "./ressources/webshell_expected.json");
activation_function); /*
for (int neurons_number = 5; neurons_number < 50; neurons_number++) { NeuralNetwork nn = new NeuralNetwork(20, learning_rate,
System.out.println("Number of neurons in hidden layer : " optimization_algorithm, activation_function);
+ neurons_number); nn.runCSV("webshell_data.csv");
lnn.learning(neurons_number); */
}
} }
} }
...@@ -5,6 +5,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader; ...@@ -5,6 +5,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.util.ClassPathResource; import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.DoublesDataSetIterator;
import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.ROC; import org.deeplearning4j.eval.ROC;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
...@@ -22,14 +23,17 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; ...@@ -22,14 +23,17 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.primitives.Pair;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/** /**
* Class to test neural network learning. * Class to test neural network learning.
*/ */
public final class LearningNeuralNetwork { public final class NeuralNetwork {
/** /**
* Class count. * Class count.
...@@ -41,47 +45,28 @@ public final class LearningNeuralNetwork { ...@@ -41,47 +45,28 @@ public final class LearningNeuralNetwork {
public static final int FEATURES_COUNT = 5; public static final int FEATURES_COUNT = 5;
private DataSet training_data; private int neurons_number;
private DataSet testing_data;
private double learning_rate; private double learning_rate;
private OptimizationAlgorithm algorithm; private OptimizationAlgorithm algorithm;
private Activation activation_function; private Activation activation_function;
/** /**
* Constructor. * @param neurons_number
* * @param learning_rate
* @param file_name * @param algorithm
* @param activation_function
*/ */
public LearningNeuralNetwork(final String file_name, public NeuralNetwork(final int neurons_number,
final double learning_rate, final double learning_rate,
final OptimizationAlgorithm algorithm, final OptimizationAlgorithm algorithm,
final Activation activation_function final Activation activation_function
) { ) {
try (RecordReader record_reader = new CSVRecordReader(0, ',')) {
record_reader.initialize(new FileSplit(
new ClassPathResource(file_name).getFile()
));
DataSetIterator iterator = new RecordReaderDataSetIterator(
record_reader, 12468, FEATURES_COUNT, CLASSES_COUNT);
DataSet all_data = iterator.next();
all_data.shuffle();
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(all_data);
normalizer.transform(all_data);
SplitTestAndTrain test_and_train = all_data.splitTestAndTrain(0.90);
this.training_data = test_and_train.getTrain();
this.testing_data = test_and_train.getTest();
} catch (IOException e) {
e.printStackTrace();
} catch (InterruptedException e) {
e.printStackTrace();
}
this.learning_rate = learning_rate; this.learning_rate = learning_rate;
this.algorithm = algorithm; this.algorithm = algorithm;
this.activation_function = activation_function; this.activation_function = activation_function;
this.neurons_number = neurons_number;
} }
/** /**
...@@ -89,7 +74,10 @@ public final class LearningNeuralNetwork { ...@@ -89,7 +74,10 @@ public final class LearningNeuralNetwork {
* *
* @param neurons_number * @param neurons_number
*/ */
public void learning(final int neurons_number) { MultiLayerNetwork learning(
final int neurons_number,
final DataSet learning,
final DataSet testing) {
MultiLayerConfiguration configuration MultiLayerConfiguration configuration
= new NeuralNetConfiguration.Builder() = new NeuralNetConfiguration.Builder()
...@@ -115,16 +103,100 @@ public final class LearningNeuralNetwork { ...@@ -115,16 +103,100 @@ public final class LearningNeuralNetwork {
MultiLayerNetwork model = new MultiLayerNetwork(configuration); MultiLayerNetwork model = new MultiLayerNetwork(configuration);
model.init(); model.init();
model.fit(this.training_data); model.fit(learning);
INDArray output INDArray output
= model.output((this.testing_data.getFeatureMatrix())); = model.output((testing.getFeatureMatrix()));
Evaluation eval = new Evaluation(CLASSES_COUNT); Evaluation eval = new Evaluation(CLASSES_COUNT);
eval.eval(this.testing_data.getLabels(), output); eval.eval(testing.getLabels(), output);
System.out.println(eval.stats()); System.out.println(eval.stats());
ROC roc = new ROC(CLASSES_COUNT); ROC roc = new ROC(CLASSES_COUNT);
roc.eval(this.testing_data.getLabels(), output); roc.eval(testing.getLabels(), output);
System.out.println(roc.stats()); System.out.println(roc.stats());
return model;
}
/**
* @param data
* @param expected
* @return
*/
public MultiLayerNetwork run(
final List<List<Double>> data,
final List<Double> expected) {
if (data.size() != expected.size()) {
throw new IllegalArgumentException(
"Data and Expected must have the same size");
}
List<Pair<double[], double[]>> pair = new ArrayList<>();
for (int i = 0; i < expected.size(); i++) {
double[] da = new double[data.get(i).size()];
double[] ex = {1.0, 0.0};
for (int j = 0; j < data.get(i).size(); j++) {
da[j] = data.get(i).get(j);
}
Pair<double[], double[]> p = new Pair<>();
p.setKey(da);
p.setValue(ex);
pair.add(p);
}
DataSetIterator doubles_dataset = new DoublesDataSetIterator(pair,
12468);
DataSet all_data = doubles_dataset.next();
all_data.shuffle();
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(all_data);
normalizer.transform(all_data);
SplitTestAndTrain test_train = all_data.splitTestAndTrain(0.8);
DataSet training_data = test_train.getTrain();
DataSet testing_data = test_train.getTest();
return this.learning(neurons_number, training_data, testing_data);
}
/**
* @param data_filename
* @param expected_filename
* @return
*/
public MultiLayerNetwork run(
final String data_filename,
final String expected_filename) {
List<List<Double>> data
= Utils.convertJsonToDataForTrainer(data_filename);
List<Double> expected
= Utils.convertJsonToExpectedForTrainer(expected_filename);
return this.run(data, expected);
}
/**
* @param filename
* @return
*/
public MultiLayerNetwork runCSV(final String filename) {
try (RecordReader record_reader = new CSVRecordReader(0, ',')) {
record_reader.initialize(new FileSplit(
new ClassPathResource(filename).getFile()
));
DataSetIterator iterator = new RecordReaderDataSetIterator(
record_reader, 12468, FEATURES_COUNT, CLASSES_COUNT);
DataSet all_data = iterator.next();
all_data.shuffle();
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(all_data);
normalizer.transform(all_data);
SplitTestAndTrain test_and_train = all_data.splitTestAndTrain(0.80);
DataSet training_data = test_and_train.getTrain();
DataSet testing_data = test_and_train.getTest();
return learning(neurons_number, training_data, testing_data);
} catch (IOException e) {
e.printStackTrace();
} catch (InterruptedException e) {
e.printStackTrace();
}
return null;
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment