Skip to content
Snippets Groups Projects
Commit 6ecc8c06 authored by a.croix
Browse files

Add parameters for neural network learning

parent 31bd610c
No related branches found
No related tags found
1 merge request!4Neural network
Pipeline #1936 passed
......@@ -6,6 +6,7 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
......@@ -41,12 +42,20 @@ public final class LearningNeuralNetwork {
private DataSet training_data;
private DataSet testing_data;
private double learning_rate;
private OptimizationAlgorithm algorithm;
private Activation activation_function;
/**
 * Constructor.
 * @param file_name name of the CSV resource containing the data set
 * @param learning_rate learning rate used during training
 * @param algorithm optimization algorithm used to train the network
 * @param activation_function activation function applied in the layers
 */
public LearningNeuralNetwork(final String file_name) {
public LearningNeuralNetwork(final String file_name,
final double learning_rate,
final OptimizationAlgorithm algorithm,
final Activation activation_function
) {
try (RecordReader record_reader = new CSVRecordReader(0, ',')) {
record_reader.initialize(new FileSplit(
new ClassPathResource(file_name).getFile()
......@@ -68,6 +77,9 @@ public final class LearningNeuralNetwork {
} catch (InterruptedException e) {
e.printStackTrace();
}
this.learning_rate = learning_rate;
this.algorithm = algorithm;
this.activation_function = activation_function;
}
/**
......@@ -80,9 +92,11 @@ public final class LearningNeuralNetwork {
MultiLayerConfiguration configuration
= new NeuralNetConfiguration.Builder()
.iterations(1000)
.activation(Activation.RELU)
.activation(this.activation_function)
.optimizationAlgo(
this.algorithm)
.weightInit(WeightInit.XAVIER)
.learningRate(0.1)
.learningRate(this.learning_rate)
.regularization(true).l2(0.0001)
.list()
.layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT)
......
package be.cylab.java.wowa.training;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.nd4j.linalg.activations.Activation;
/**
 * Class for learning with a neural network.
 */
......@@ -18,8 +21,42 @@ public final class MainDL4J {
*/
public static void main(final String[] args) {
String file_name = args[0];
double learning_rate = Double.parseDouble(args[1]);
OptimizationAlgorithm optimization_algorithm;
Activation activation_function;
if (args[2] == "CONJUGATE_GRADIENT") {
optimization_algorithm = OptimizationAlgorithm.CONJUGATE_GRADIENT;
} else if (args[2] == "LBFGS") {
optimization_algorithm = OptimizationAlgorithm.LBFGS;
} else if (args[2] == "LINE_GRADIENT_DESCENT") {
optimization_algorithm
= OptimizationAlgorithm.LINE_GRADIENT_DESCENT;
} else if (args[2] == "STOCHASTIC_GRADIENT_DESCENT") {
optimization_algorithm
= OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
} else {
throw new IllegalArgumentException(
"Not correct optimization algorithm");
}
if (args[3] == "RELU") {
activation_function = Activation.RELU;
} else if (args[3] == "SIGMOID") {
activation_function = Activation.SIGMOID;
} else if (args[3] == "TANH") {
activation_function = Activation.TANH;
} else {
throw new IllegalArgumentException(
"Not correct activation function");
}
LearningNeuralNetwork lnn
= new LearningNeuralNetwork("webshell_data.csv");
= new LearningNeuralNetwork(file_name,
learning_rate,
optimization_algorithm,
activation_function);
for (int neurons_number = 5; neurons_number < 50; neurons_number++) {
System.out.println("Number of neurons in hidden layer : "
+ neurons_number);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment