diff --git a/java-wowa-training.iml b/java-wowa-training.iml index 08c09f61200f2f2388e08bc4ac406e511612c382..07a0b767d4bdc777f3b37cf6f213126ed7f88d11 100644 --- a/java-wowa-training.iml +++ b/java-wowa-training.iml @@ -32,6 +32,97 @@ <SOURCES /> </library> </orderEntry> + <orderEntry type="library" name="Maven: org.datavec:datavec-api:0.9.1" level="project" /> + <orderEntry type="library" name="Maven: org.apache.commons:commons-compress:1.8.1" level="project" /> + <orderEntry type="library" name="Maven: org.apache.commons:commons-math3:3.3" level="project" /> + <orderEntry type="library" name="Maven: commons-io:commons-io:2.4" level="project" /> + <orderEntry type="library" name="Maven: org.slf4j:slf4j-api:1.7.7" level="project" /> + <orderEntry type="library" name="Maven: joda-time:joda-time:2.9.2" level="project" /> + <orderEntry type="library" name="Maven: org.nd4j:jackson:0.9.1" level="project" /> + <orderEntry type="library" name="Maven: org.yaml:snakeyaml:1.12" level="project" /> + <orderEntry type="library" name="Maven: org.codehaus.woodstox:stax2-api:3.1.4" level="project" /> + <orderEntry type="library" name="Maven: org.projectlombok:lombok:1.16.10" level="project" /> + <orderEntry type="library" name="Maven: org.freemarker:freemarker:2.3.23" level="project" /> + <orderEntry type="library" name="Maven: org.reflections:reflections:0.9.10" level="project" /> + <orderEntry type="library" name="Maven: org.javassist:javassist:3.19.0-GA" level="project" /> + <orderEntry type="library" name="Maven: com.google.code.findbugs:annotations:2.0.1" level="project" /> + <orderEntry type="library" name="Maven: org.nd4j:nd4j-common:0.9.1" level="project" /> + <orderEntry type="library" name="Maven: com.github.stephenc.findbugs:findbugs-annotations:1.3.9-1" level="project" /> + <orderEntry type="library" name="Maven: com.clearspring.analytics:stream:2.7.0" level="project" /> + <orderEntry type="library" name="Maven: it.unimi.dsi:fastutil:6.5.7" level="project" /> + 
<orderEntry type="library" name="Maven: net.sf.opencsv:opencsv:2.3" level="project" /> + <orderEntry type="library" name="Maven: org.nd4j:nd4j-api:0.9.1" level="project" /> + <orderEntry type="library" name="Maven: org.nd4j:nd4j-buffer:0.9.1" level="project" /> + <orderEntry type="library" name="Maven: org.nd4j:nd4j-context:0.9.1" level="project" /> + <orderEntry type="library" name="Maven: net.ericaro:neoitertools:1.0.0" level="project" /> + <orderEntry type="library" name="Maven: junit:junit:4.8.2" level="project" /> + <orderEntry type="library" name="Maven: org.nd4j:nd4j-native:0.9.1" level="project" /> + <orderEntry type="library" name="Maven: org.nd4j:nd4j-native:linux-x86_64:0.9.1" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco:javacpp:1.3.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:openblas:0.2.19-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:openblas:linux-x86_64:0.2.19-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.nd4j:nd4j-native-api:0.9.1" level="project" /> + <orderEntry type="library" name="Maven: org.deeplearning4j:deeplearning4j-nn:0.9.1" level="project" /> + <orderEntry type="library" name="Maven: org.nd4j:nd4j-jackson:0.9.1" level="project" /> + <orderEntry type="library" name="Maven: org.nd4j:nd4j-base64:0.9.1" level="project" /> + <orderEntry type="library" name="Maven: commons-net:commons-net:3.1" level="project" /> + <orderEntry type="library" name="Maven: org.deeplearning4j:deeplearning4j-core:0.9.1" level="project" /> + <orderEntry type="library" name="Maven: org.deeplearning4j:nearestneighbor-core:0.9.1" level="project" /> + <orderEntry type="library" name="Maven: org.deeplearning4j:deeplearning4j-modelimport:0.9.1" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5-platform:1.10.0-patch1-1.3" level="project" /> + <orderEntry type="library" 
name="Maven: org.bytedeco.javacpp-presets:hdf5:1.10.0-patch1-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:linux-x86:1.10.0-patch1-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:linux-x86_64:1.10.0-patch1-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:linux-ppc64le:1.10.0-patch1-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:macosx-x86_64:1.10.0-patch1-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:windows-x86:1.10.0-patch1-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:hdf5:windows-x86_64:1.10.0-patch1-1.3" level="project" /> + <orderEntry type="library" name="Maven: com.google.guava:guava:20.0" level="project" /> + <orderEntry type="library" name="Maven: org.datavec:datavec-data-image:0.9.1" level="project" /> + <orderEntry type="library" name="Maven: com.github.jai-imageio:jai-imageio-core:1.3.0" level="project" /> + <orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-jpeg:3.1.1" level="project" /> + <orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-core:3.1.1" level="project" /> + <orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-metadata:3.1.1" level="project" /> + <orderEntry type="library" name="Maven: com.twelvemonkeys.common:common-lang:3.1.1" level="project" /> + <orderEntry type="library" name="Maven: com.twelvemonkeys.common:common-io:3.1.1" level="project" /> + <orderEntry type="library" name="Maven: com.twelvemonkeys.common:common-image:3.1.1" level="project" /> + <orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-tiff:3.1.1" level="project" /> + <orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-psd:3.1.1" level="project" /> + 
<orderEntry type="library" name="Maven: com.twelvemonkeys.imageio:imageio-bmp:3.1.1" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco:javacv:1.3.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:3.2.0-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:ffmpeg:3.2.1-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:flycapture:2.9.3.43-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:libdc1394:2.2.4-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:libfreenect:0.5.3-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:libfreenect2:0.2.0-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:librealsense:1.9.6-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:videoinput:0.200-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:artoolkitplus:2.3.1-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:flandmark:1.07-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv-platform:3.2.0-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:android-arm:3.2.0-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:android-x86:3.2.0-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:linux-x86:3.2.0-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:linux-x86_64:3.2.0-1.3" level="project" /> + <orderEntry type="library" name="Maven: 
org.bytedeco.javacpp-presets:opencv:linux-armhf:3.2.0-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:linux-ppc64le:3.2.0-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:macosx-x86_64:3.2.0-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:windows-x86:3.2.0-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:opencv:windows-x86_64:3.2.0-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica-platform:1.73-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:1.73-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:android-arm:1.73-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:android-x86:1.73-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:linux-x86:1.73-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:linux-x86_64:1.73-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:linux-armhf:1.73-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:linux-ppc64le:1.73-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:macosx-x86_64:1.73-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:windows-x86:1.73-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.bytedeco.javacpp-presets:leptonica:windows-x86_64:1.73-1.3" level="project" /> + <orderEntry type="library" name="Maven: org.deeplearning4j:deeplearning4j-ui-components:0.9.1" 
level="project" /> + <orderEntry type="library" name="Maven: commons-codec:commons-codec:1.10" level="project" /> <orderEntry type="library" scope="TEST" name="Maven: org.junit.jupiter:junit-jupiter-api:5.3.1" level="project" /> <orderEntry type="library" scope="TEST" name="Maven: org.apiguardian:apiguardian-api:1.0.0" level="project" /> <orderEntry type="library" scope="TEST" name="Maven: org.opentest4j:opentest4j:1.1.1" level="project" /> diff --git a/src/main/java/be/cylab/java/wowa/training/MainDL4J.java b/src/main/java/be/cylab/java/wowa/training/MainDL4J.java index d8a5b18a6ac2110b44ccec1bc938a8e6f14e6400..727407e9b63db631aba74d0f3b7ed58796a50189 100644 --- a/src/main/java/be/cylab/java/wowa/training/MainDL4J.java +++ b/src/main/java/be/cylab/java/wowa/training/MainDL4J.java @@ -43,7 +43,7 @@ public final class MainDL4J { if (args[3].matches("RELU")) { activation_function = Activation.RELU; - } else if (args[3].matches("SIGMOID")) { + } else if (args[3].matches("SIGMOID")) { activation_function = Activation.SIGMOID; } else if (args[3].matches("TANH")) { activation_function = Activation.TANH; @@ -52,16 +52,14 @@ public final class MainDL4J { "Not correct activation function"); } - LearningNeuralNetwork lnn - = new LearningNeuralNetwork(file_name, - learning_rate, - optimization_algorithm, - activation_function); - for (int neurons_number = 5; neurons_number < 50; neurons_number++) { - System.out.println("Number of neurons in hidden layer : " - + neurons_number); - lnn.learning(neurons_number); - } - + NeuralNetwork nn = new NeuralNetwork(20, learning_rate, + optimization_algorithm, activation_function); + nn.run("./ressources/webshell_data.json", + "./ressources/webshell_expected.json"); + /* + NeuralNetwork nn = new NeuralNetwork(20, learning_rate, + optimization_algorithm, activation_function); + nn.runCSV("webshell_data.csv"); + */ } } diff --git a/src/main/java/be/cylab/java/wowa/training/LearningNeuralNetwork.java 
b/src/main/java/be/cylab/java/wowa/training/NeuralNetwork.java similarity index 54% rename from src/main/java/be/cylab/java/wowa/training/LearningNeuralNetwork.java rename to src/main/java/be/cylab/java/wowa/training/NeuralNetwork.java index 7328913d5b67cd2845395633e1ffff9b547b04eb..6a6339aaf5a52063437ce77d6563823050923663 100644 --- a/src/main/java/be/cylab/java/wowa/training/LearningNeuralNetwork.java +++ b/src/main/java/be/cylab/java/wowa/training/NeuralNetwork.java @@ -5,6 +5,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.util.ClassPathResource; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; +import org.deeplearning4j.datasets.iterator.DoublesDataSetIterator; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.eval.ROC; import org.deeplearning4j.nn.api.OptimizationAlgorithm; @@ -22,14 +23,17 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import org.nd4j.linalg.lossfunctions.LossFunctions; +import org.nd4j.linalg.primitives.Pair; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; /** * Class to test neural network learning. */ -public final class LearningNeuralNetwork { +public final class NeuralNetwork { /** * Class count. @@ -41,47 +45,28 @@ public final class LearningNeuralNetwork { public static final int FEATURES_COUNT = 5; - private DataSet training_data; - private DataSet testing_data; + private int neurons_number; private double learning_rate; private OptimizationAlgorithm algorithm; private Activation activation_function; /** - * Constructor. 
- * - * @param file_name + * @param neurons_number + * @param learning_rate + * @param algorithm + * @param activation_function */ - public LearningNeuralNetwork(final String file_name, - final double learning_rate, - final OptimizationAlgorithm algorithm, - final Activation activation_function + public NeuralNetwork(final int neurons_number, + final double learning_rate, + final OptimizationAlgorithm algorithm, + final Activation activation_function ) { - try (RecordReader record_reader = new CSVRecordReader(0, ',')) { - record_reader.initialize(new FileSplit( - new ClassPathResource(file_name).getFile() - )); - DataSetIterator iterator = new RecordReaderDataSetIterator( - record_reader, 12468, FEATURES_COUNT, CLASSES_COUNT); - DataSet all_data = iterator.next(); - all_data.shuffle(); - DataNormalization normalizer = new NormalizerStandardize(); - normalizer.fit(all_data); - normalizer.transform(all_data); - - SplitTestAndTrain test_and_train = all_data.splitTestAndTrain(0.90); - this.training_data = test_and_train.getTrain(); - this.testing_data = test_and_train.getTest(); - } catch (IOException e) { - e.printStackTrace(); - } catch (InterruptedException e) { - e.printStackTrace(); - } this.learning_rate = learning_rate; this.algorithm = algorithm; this.activation_function = activation_function; + this.neurons_number = neurons_number; } /** @@ -89,7 +74,10 @@ public final class LearningNeuralNetwork { * * @param neurons_number */ - public void learning(final int neurons_number) { + MultiLayerNetwork learning( + final int neurons_number, + final DataSet learning, + final DataSet testing) { MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder() @@ -115,16 +103,100 @@ public final class LearningNeuralNetwork { MultiLayerNetwork model = new MultiLayerNetwork(configuration); model.init(); - model.fit(this.training_data); + model.fit(learning); INDArray output - = model.output((this.testing_data.getFeatureMatrix())); + = 
model.output((testing.getFeatureMatrix())); Evaluation eval = new Evaluation(CLASSES_COUNT); - eval.eval(this.testing_data.getLabels(), output); + eval.eval(testing.getLabels(), output); System.out.println(eval.stats()); ROC roc = new ROC(CLASSES_COUNT); - roc.eval(this.testing_data.getLabels(), output); + roc.eval(testing.getLabels(), output); System.out.println(roc.stats()); + return model; + } + + /** Trains and evaluates a network on in-memory samples. + * @param data feature vectors, one inner list per sample + * @param expected label per sample (0.0 or 1.0), same order as data + * @return the network fitted on the training split + */ + public MultiLayerNetwork run( + final List<List<Double>> data, + final List<Double> expected) { + if (data.size() != expected.size()) { + throw new IllegalArgumentException( + "Data and Expected must have the same size"); + } + List<Pair<double[], double[]>> pair = new ArrayList<>(); + for (int i = 0; i < expected.size(); i++) { + double[] da = new double[data.get(i).size()]; + double[] ex = {1.0 - expected.get(i), expected.get(i)}; + for (int j = 0; j < data.get(i).size(); j++) { + da[j] = data.get(i).get(j); + } + Pair<double[], double[]> p = new Pair<>(); + p.setKey(da); + p.setValue(ex); + pair.add(p); + } + DataSetIterator doubles_dataset = new DoublesDataSetIterator(pair, + pair.size()); + DataSet all_data = doubles_dataset.next(); + all_data.shuffle(); + DataNormalization normalizer = new NormalizerStandardize(); + normalizer.fit(all_data); + normalizer.transform(all_data); + SplitTestAndTrain test_train = all_data.splitTestAndTrain(0.8); + DataSet training_data = test_train.getTrain(); + DataSet testing_data = test_train.getTest(); + return this.learning(neurons_number, training_data, testing_data); + + } + + /** Loads samples and labels from JSON files, then trains on them. + * @param data_filename path of the JSON file of feature vectors + * @param expected_filename path of the JSON file of expected labels + * @return the network fitted on the training split + */ + public MultiLayerNetwork run( + final String data_filename, + final String expected_filename) { + List<List<Double>> data + = Utils.convertJsonToDataForTrainer(data_filename); + List<Double> expected + = Utils.convertJsonToExpectedForTrainer(expected_filename); + return this.run(data, expected); + } + + /** Trains and evaluates a network on a CSV classpath resource. + * @param filename name of the CSV resource on the classpath + * @return the fitted network, or null if the resource cannot be read + */ + public 
MultiLayerNetwork runCSV(final String filename) { + try (RecordReader record_reader = new CSVRecordReader(0, ',')) { + record_reader.initialize(new FileSplit( + new ClassPathResource(filename).getFile() + )); + DataSetIterator iterator = new RecordReaderDataSetIterator( + record_reader, 12468, FEATURES_COUNT, CLASSES_COUNT); + DataSet all_data = iterator.next(); + all_data.shuffle(); + DataNormalization normalizer = new NormalizerStandardize(); + normalizer.fit(all_data); + normalizer.transform(all_data); + + SplitTestAndTrain test_and_train = all_data.splitTestAndTrain(0.80); + DataSet training_data = test_and_train.getTrain(); + DataSet testing_data = test_and_train.getTest(); + return learning(neurons_number, training_data, testing_data); + + } catch (IOException e) { + e.printStackTrace(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + return null; } }