Commit 0c52e8a9 authored by a.croix

Set neural network on specific branch

parent 45204932
Pipeline #1821 passed
@@ -98,17 +98,6 @@
             <version>0.0.3</version>
         </dependency>
-        <dependency>
-            <groupId>org.nd4j</groupId>
-            <artifactId>nd4j-native-platform</artifactId>
-            <version>0.9.1</version>
-        </dependency>
-        <dependency>
-            <groupId>org.deeplearning4j</groupId>
-            <artifactId>deeplearning4j-core</artifactId>
-            <version>0.9.1</version>
-        </dependency>
     </dependencies>
MainDL4J.java
package be.cylab.java.wowa.training;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.IOException;

/**
 * Trains and evaluates a neural network (DL4J) on the webshell data set.
 */
public final class MainDL4J {

    /**
     * Private constructor: utility class, not meant to be instantiated.
     */
    private MainDL4J() {
    }

    /**
     * Number of output classes.
     */
    public static final int CLASSES_COUNT = 2;

    /**
     * Number of input features.
     */
    public static final int FEATURES_COUNT = 5;

    /**
     * Entry point: trains the network on the webshell data set and prints
     * the evaluation statistics.
     *
     * @param args command line arguments (not used)
     */
    public static void main(final String[] args) {
        try (RecordReader record_reader = new CSVRecordReader(0, ',')) {
            // Read the whole CSV file from the classpath.
            record_reader.initialize(new FileSplit(
                    new ClassPathResource("webshell_data.csv").getFile()
            ));
            DataSetIterator iterator = new RecordReaderDataSetIterator(
                    record_reader, 12468, FEATURES_COUNT, CLASSES_COUNT);
            DataSet all_data = iterator.next();
            all_data.shuffle();

            // Standardize features (zero mean, unit variance).
            DataNormalization normalizer = new NormalizerStandardize();
            normalizer.fit(all_data);
            normalizer.transform(all_data);

            // 75% training data, 25% test data.
            SplitTestAndTrain test_and_train = all_data.splitTestAndTrain(0.75);
            DataSet training_data = test_and_train.getTrain();
            DataSet test_data = test_and_train.getTest();

            // Two hidden layers (tanh) and a softmax output layer.
            MultiLayerConfiguration configuration
                    = new NeuralNetConfiguration.Builder()
                    .iterations(1000)
                    .activation(Activation.TANH)
                    .weightInit(WeightInit.XAVIER)
                    .learningRate(0.1)
                    .regularization(true).l2(0.0001)
                    .list()
                    .layer(0, new DenseLayer.Builder()
                            .nIn(FEATURES_COUNT).nOut(10).build())
                    .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).build())
                    .layer(2, new OutputLayer.Builder(
                            LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                            .activation(Activation.SOFTMAX)
                            .nIn(10).nOut(CLASSES_COUNT).build())
                    .backprop(true).pretrain(false)
                    .build();

            MultiLayerNetwork model = new MultiLayerNetwork(configuration);
            model.init();
            model.fit(training_data);

            // Evaluate on the held-out test set.
            INDArray output = model.output(test_data.getFeatureMatrix());
            Evaluation eval = new Evaluation(CLASSES_COUNT);
            eval.eval(test_data.getLabels(), output);
            System.out.println(eval.stats());

        } catch (IOException e) {
            e.printStackTrace();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}
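
The removed class only evaluates the network on the held-out split. As a hedged illustration, the sketch below shows how the same 0.9.1 API could classify a single raw feature row after training, reusing the model and the fitted normalizer from main(). The helper class, its name and the dummy label row are assumptions added here for illustration; they are not part of this commit.

package be.cylab.java.wowa.training;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Illustrative helper (not part of the commit): classifies one raw feature
 * row with the trained network, reusing the normalizer fitted in main().
 */
final class SinglePredictionSketch {

    private SinglePredictionSketch() {
    }

    static int predict(
            final MultiLayerNetwork model,
            final DataNormalization normalizer,
            final double[] raw_features) {
        // Wrap the row as a 1 x FEATURES_COUNT matrix with a dummy label row,
        // so the statistics fitted on all_data can be reused as-is.
        DataSet wrapped = new DataSet(
                Nd4j.create(raw_features),
                Nd4j.zeros(1, MainDL4J.CLASSES_COUNT));
        normalizer.transform(wrapped);
        // Softmax output: one probability per class; argmax is the label.
        INDArray probabilities = model.output(wrapped.getFeatureMatrix());
        return Nd4j.argMax(probabilities, 1).getInt(0);
    }
}

Passing the row through a DataSet keeps the call identical to the one used on all_data; NormalizerStandardize only touches the features unless fitLabel(true) was requested.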