From 0c52e8a9bc38bc358abb0198e31c62cddf2ca6c4 Mon Sep 17 00:00:00 2001 From: Alex <croix.alexandre@gmail.com> Date: Fri, 7 Jun 2019 10:56:28 +0200 Subject: [PATCH] Set neural network on specific branch --- pom.xml | 11 -- .../be/cylab/java/wowa/training/MainDL4J.java | 103 ------------------ 2 files changed, 114 deletions(-) delete mode 100644 src/main/java/be/cylab/java/wowa/training/MainDL4J.java diff --git a/pom.xml b/pom.xml index 7db5371..9bbf15c 100644 --- a/pom.xml +++ b/pom.xml @@ -98,17 +98,6 @@ <version>0.0.3</version> </dependency> - <dependency> - <groupId>org.nd4j</groupId> - <artifactId>nd4j-native-platform</artifactId> - <version>0.9.1</version> - </dependency> - - <dependency> - <groupId>org.deeplearning4j</groupId> - <artifactId>deeplearning4j-core</artifactId> - <version>0.9.1</version> - </dependency> </dependencies> diff --git a/src/main/java/be/cylab/java/wowa/training/MainDL4J.java b/src/main/java/be/cylab/java/wowa/training/MainDL4J.java deleted file mode 100644 index 3209094..0000000 --- a/src/main/java/be/cylab/java/wowa/training/MainDL4J.java +++ /dev/null @@ -1,103 +0,0 @@ -package be.cylab.java.wowa.training; - -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.records.reader.impl.csv.CSVRecordReader; -import org.datavec.api.split.FileSplit; -import org.datavec.api.util.ClassPathResource; -import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; -import org.deeplearning4j.eval.Evaluation; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.DenseLayer; -import org.deeplearning4j.nn.conf.layers.OutputLayer; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.nn.weights.WeightInit; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.DataSet; -import org.nd4j.linalg.dataset.SplitTestAndTrain; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; -import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; -import org.nd4j.linalg.lossfunctions.LossFunctions; - -import java.io.IOException; - -/** - * Class for learn in neuronal network. - */ -public final class MainDL4J { - /** - * Default constructor. - */ - private MainDL4J() { - - } - - /** - * Class count. - */ - public static final int CLASSES_COUNT = 2; - /** - * Features count. - */ - public static final int FEATURES_COUNT = 5; - - /** - * Main class for deep-learning. - * - * @param args - */ - public static void main(final String[] args) { - - - try (RecordReader record_reader = new CSVRecordReader(0, ',')) { - record_reader.initialize(new FileSplit( - new ClassPathResource("webshell_data.csv").getFile() - )); - DataSetIterator iterator = new RecordReaderDataSetIterator( - record_reader, 12468, FEATURES_COUNT, CLASSES_COUNT); - DataSet all_data = iterator.next(); - all_data.shuffle(); - - DataNormalization normalizer = new NormalizerStandardize(); - normalizer.fit(all_data); - normalizer.transform(all_data); - - SplitTestAndTrain test_and_train = all_data.splitTestAndTrain(0.75); - DataSet training_data = test_and_train.getTrain(); - DataSet test_data = test_and_train.getTest(); - - MultiLayerConfiguration configuration - = new NeuralNetConfiguration.Builder() - .iterations(1000) - .activation(Activation.TANH) - .weightInit(WeightInit.XAVIER) - .learningRate(0.1) - .regularization(true).l2(0.0001) - .list() - .layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(10).build()) - .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).build()) - .layer(2, new OutputLayer.Builder( - LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .activation(Activation.SOFTMAX) - .nIn(10).nOut(CLASSES_COUNT).build()) - .backprop(true).pretrain(false) - .build(); - - MultiLayerNetwork model = new MultiLayerNetwork(configuration); - model.init(); - model.fit(training_data); - - INDArray output = model.output((test_data.getFeatureMatrix())); - - Evaluation eval = new Evaluation(CLASSES_COUNT); - eval.eval(test_data.getLabels(), output); - System.out.println(eval.stats()); - } catch (IOException e) { - e.printStackTrace(); - } catch (InterruptedException e) { - e.printStackTrace(); - } - } -} -- GitLab