Skip to content
Snippets Groups Projects
Commit 54dc6aeb authored by a.croix's avatar a.croix
Browse files

All learning used the same testing dataset and learning dataset. Better to...

All learning used the same testing dataset and learning dataset. Better to compare performance of hyper-parameters
parent c9ffea73
No related branches found
No related tags found
1 merge request!4Neural network
Pipeline #1830 passed
......@@ -24,15 +24,11 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.IOException;
/**
 * Class to test neural network learning.
 */
public final class LearningNeuralNetwork {
/**
* Class count.
*/
......@@ -43,72 +39,73 @@ public final class LearningNeuralNetwork {
public static final int FEATURES_COUNT = 5;
/**
 * Training split (80% of the shuffled, normalized dataset),
 * shared by every call to learning().
 */
private DataSet training_data;
/**
 * Testing split (remaining 20%), used to evaluate each trained model.
 */
private DataSet testing_data;
/**
 * Constructor. Loads the dataset from the given CSV classpath resource,
 * shuffles and standardizes it, then splits it once into a training set
 * (80%) and a testing set (20%). Doing the split here (instead of inside
 * each learning run) ensures every hyper-parameter configuration is
 * compared on the same data.
 *
 * @param file_name name of the CSV resource on the classpath
 */
public LearningNeuralNetwork(final String file_name) {
    try (RecordReader record_reader = new CSVRecordReader(0, ',')) {
        record_reader.initialize(new FileSplit(
                new ClassPathResource(file_name).getFile()
        ));
        // Batch size 12468 reads the whole file in a single DataSet so
        // the split below sees every row at once.
        DataSetIterator iterator = new RecordReaderDataSetIterator(
                record_reader, 12468, FEATURES_COUNT, CLASSES_COUNT);
        DataSet all_data = iterator.next();
        // Shuffle before splitting so train/test rows are drawn randomly.
        all_data.shuffle();
        // Standardize features (zero mean, unit variance).
        DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(all_data);
        normalizer.transform(all_data);
        SplitTestAndTrain test_and_train = all_data.splitTestAndTrain(0.80);
        this.training_data = test_and_train.getTrain();
        this.testing_data = test_and_train.getTest();
    } catch (IOException e) {
        e.printStackTrace();
    } catch (InterruptedException e) {
        // Restore the interrupt flag so callers can still observe it.
        Thread.currentThread().interrupt();
        e.printStackTrace();
    }
}
/**
 * Method to build a simple neural network, train and test it.
 * The network has two dense hidden layers of {@code neurons_number}
 * neurons each and a softmax output layer; it is trained on the shared
 * training split and evaluation statistics on the shared testing split
 * are printed to standard output.
 *
 * @param neurons_number number of neurons in each hidden layer
 */
public void learning(final int neurons_number) {
    MultiLayerConfiguration configuration
            = new NeuralNetConfiguration.Builder()
            .iterations(1000)
            .activation(Activation.RELU)
            .weightInit(WeightInit.XAVIER)
            .learningRate(0.1)
            .regularization(true).l2(0.0001)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT)
                    .nOut(neurons_number).build())
            .layer(1, new DenseLayer.Builder().nIn(neurons_number)
                    .nOut(neurons_number).build())
            .layer(2, new OutputLayer.Builder(LossFunctions
                    .LossFunction.NEGATIVELOGLIKELIHOOD)
                    .activation(Activation.SOFTMAX)
                    .nIn(neurons_number)
                    .nOut(CLASSES_COUNT).build())
            .backprop(true).pretrain(false)
            .build();
    MultiLayerNetwork model = new MultiLayerNetwork(configuration);
    model.init();
    model.fit(this.training_data);
    // Evaluate on the held-out testing split.
    INDArray output = model.output(this.testing_data.getFeatureMatrix());
    Evaluation eval = new Evaluation(CLASSES_COUNT);
    eval.eval(this.testing_data.getLabels(), output);
    System.out.println(eval.stats());
}
}
......@@ -18,8 +18,10 @@ public final class MainDL4J {
*/
public static void main(final String[] args) {
    // Load, normalize and split the dataset exactly once so that every
    // hidden-layer size below is evaluated on the same train/test data.
    LearningNeuralNetwork lnn
            = new LearningNeuralNetwork("webshell_data.csv");
    // Compare performance for hidden layers of 5 to 49 neurons.
    for (int neurons_number = 5; neurons_number < 50; neurons_number++) {
        lnn.learning(neurons_number);
    }
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment