Commit 44cbeb73 authored by a.croix

Merge run-data-set-iterator in Neural Network branch

parents 2a4297f7 42150423
Merge request !4: Neural network
Pipeline #2331 failed
Showing 788 additions and 152 deletions
@@ -42,13 +42,15 @@
<orderEntry type="library" name="Maven: info.debatty:java-aggregation:0.5" level="project" />
<orderEntry type="library" name="Maven: com.owlike:genson:1.6" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-lang3:3.8.1" level="project" />
- <orderEntry type="library" name="Maven: be.cylab:java-roc:0.0.3" level="project" />
+ <orderEntry type="library" name="Maven: be.cylab:java-roc:0.0.4" level="project" />
<orderEntry type="library" name="Maven: com.opencsv:opencsv:4.5" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-text:1.3" level="project" />
<orderEntry type="library" name="Maven: commons-beanutils:commons-beanutils:1.9.3" level="project" />
<orderEntry type="library" name="Maven: commons-logging:commons-logging:1.2" level="project" />
<orderEntry type="library" name="Maven: commons-collections:commons-collections:3.2.2" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-collections4:4.2" level="project" />
+ <orderEntry type="library" name="Maven: org.knowm.xchart:xchart:3.5.4" level="project" />
+ <orderEntry type="library" name="Maven: de.erichseifert.vectorgraphics2d:VectorGraphics2D:0.13" level="project" />
<orderEntry type="library" name="Maven: org.datavec:datavec-api:0.9.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-compress:1.8.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-math3:3.3" level="project" />
...
package be.cylab.java.wowa.training;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.nd4j.linalg.activations.Activation;
/**
* Parameters of the NeuralNetwork class.
*/
public class HyperParameters {
private int neurons_number;
private double learning_rate;
private OptimizationAlgorithm algorithm;
private Activation activation_function;
private double percent_test_train;
/**
* Default constructor.
* @param neurons_number
* @param learning_rate
* @param algorithm
* @param activation_function
* @param percent_test_train
*/
public HyperParameters(
final int neurons_number,
final double learning_rate,
final OptimizationAlgorithm algorithm,
final Activation activation_function,
final double percent_test_train) {
setNeuronsNumber(neurons_number);
setLearningRate(learning_rate);
setAlgorithm(algorithm);
setActivationFunction(activation_function);
setPercentTestTrain(percent_test_train);
}
/**
* Constructor without percent_test_train value (defaults to 99).
* @param neurons_number
* @param learning_rate
* @param algorithm
* @param activation_function
*/
public HyperParameters(
final int neurons_number,
final double learning_rate,
final OptimizationAlgorithm algorithm,
final Activation activation_function) {
this(neurons_number, learning_rate, algorithm,
activation_function, 99);
}
/**
* Getter for neurons_number.
* @return
*/
public final int getNeuronsNumber() {
return neurons_number;
}
/**
* Getter for learning rate.
* @return
*/
public final double getLearningRate() {
return learning_rate;
}
/**
* Getter for backpropagation algorithm.
* @return
*/
public final OptimizationAlgorithm getAlgorithm() {
return algorithm;
}
/**
* Getter for activation function.
* @return
*/
public final Activation getActivationFunction() {
return activation_function;
}
/**
* Getter for percent_test_train.
* @return
*/
public final double getPercentTestTrain() {
return percent_test_train;
}
/**
* @param neurons_number
*/
public final void setNeuronsNumber(final int neurons_number) {
if (neurons_number < 5 || neurons_number > 80) {
throw new IllegalArgumentException(
"Neuron number must be between 5 and 80");
}
this.neurons_number = neurons_number;
}
/**
* @param learning_rate
*/
public final void setLearningRate(final double learning_rate) {
if (learning_rate <= 0.0 || learning_rate >= 1.0) {
throw new IllegalArgumentException(
"Learning rate must be between 0 and 1");
}
this.learning_rate = learning_rate;
}
/**
* @param algorithm
*/
public final void setAlgorithm(final OptimizationAlgorithm algorithm) {
this.algorithm = algorithm;
}
/**
* @param activation_function
*/
public final void setActivationFunction(
final Activation activation_function) {
this.activation_function = activation_function;
}
/**
* @param percent_test_train
*/
public final void setPercentTestTrain(final double percent_test_train) {
if (percent_test_train <= 0 || percent_test_train >= 100) {
throw new IllegalArgumentException(
"Percentage of train must be between 0 and 100");
}
this.percent_test_train = percent_test_train;
}
}
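
As a reading aid, and not part of the commit, here is a minimal sketch of how the new HyperParameters class is meant to be used; the HyperParametersExample class name and the concrete values are illustrative only:

package be.cylab.java.wowa.training;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.nd4j.linalg.activations.Activation;
public final class HyperParametersExample {
private HyperParametersExample() {
}
public static void main(final String[] args) {
// 20 hidden neurons, learning rate 0.01, 90% of the data kept for training.
HyperParameters parameters = new HyperParameters(
20,
0.01,
OptimizationAlgorithm.CONJUGATE_GRADIENT,
Activation.TANH,
90);
System.out.println("Neurons : " + parameters.getNeuronsNumber());
// The setters validate their input: this call is rejected because
// the neuron count must stay between 5 and 80.
try {
parameters.setNeuronsNumber(200);
} catch (IllegalArgumentException e) {
System.out.println(e.getMessage());
}
}
}
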
package be.cylab.java.wowa.training;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
+ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
+ import java.io.FileOutputStream;
+ import java.io.IOException;
+ import java.io.OutputStreamWriter;
+ import java.nio.charset.StandardCharsets;
+ import java.util.HashMap;
/**
* Class for learn in neuronal network.
*/
@@ -20,20 +27,24 @@ public final class MainDL4J {
* @param args
*/
public static void main(final String[] args) {
- String file_name = args[0];
- double learning_rate = Double.parseDouble(args[1]);
+ int begin_neuron_number = Integer.parseInt(args[0]);
+ int neuron_number_end = Integer.parseInt(args[1]);
+ double learning_rate = Double.parseDouble(args[2]);
+ String data_file = args[5];
+ String expected_file = args[6];
OptimizationAlgorithm optimization_algorithm;
Activation activation_function;
+ int fold_number = 10;
+ int increase_ratio = 10;
- if (args[2].matches("CONJUGATE_GRADIENT")) {
+ if (args[3].matches("CONJUGATE_GRADIENT")) {
optimization_algorithm = OptimizationAlgorithm.CONJUGATE_GRADIENT;
- } else if (args[2].matches("LBFGS")) {
+ } else if (args[3].matches("LBFGS")) {
optimization_algorithm = OptimizationAlgorithm.LBFGS;
- } else if (args[2].matches("LINE_GRADIENT_DESCENT")) {
+ } else if (args[3].matches("LINE_GRADIENT_DESCENT")) {
optimization_algorithm
= OptimizationAlgorithm.LINE_GRADIENT_DESCENT;
- } else if (args[2].matches("STOCHASTIC_GRADIENT_DESCENT")) {
+ } else if (args[3].matches("STOCHASTIC_GRADIENT_DESCENT")) {
optimization_algorithm
= OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
} else {
@@ -41,27 +52,48 @@ public final class MainDL4J {
"Not correct optimization algorithm");
}
- if (args[3].matches("RELU")) {
+ if (args[4].matches("RELU")) {
activation_function = Activation.RELU;
- } else if (args[3].matches("SIGMOID")) {
+ } else if (args[4].matches("SIGMOID")) {
activation_function = Activation.SIGMOID;
- } else if (args[3].matches("TANH")) {
+ } else if (args[4].matches("TANH")) {
activation_function = Activation.TANH;
} else {
throw new IllegalArgumentException(
"Not correct activation function");
}
- LearningNeuralNetwork lnn
- = new LearningNeuralNetwork(file_name,
- learning_rate,
- optimization_algorithm,
- activation_function);
- for (int neurons_number = 5; neurons_number < 50; neurons_number++) {
- System.out.println("Number of neurons in hidden layer : "
- + neurons_number);
- lnn.learning(neurons_number);
+ for (int neuron_number = begin_neuron_number;
+ neuron_number < neuron_number_end; neuron_number++) {
+ double average_auc = 0;
+ System.out.println("Neuron number : " + neuron_number);
+ HyperParameters parameters
+ = new HyperParameters(neuron_number, learning_rate,
+ optimization_algorithm, activation_function);
+ NeuralNetwork nn = new NeuralNetwork(parameters);
+ HashMap<MultiLayerNetwork, Double> map = nn.runKFold(
+ data_file,
+ expected_file,
+ fold_number,
+ increase_ratio);
+ for (Double d : map.values()) {
+ average_auc = average_auc + d;
+ }
+ try (OutputStreamWriter writer = new OutputStreamWriter(
+ new FileOutputStream("Synthesis_average_AUC.txt", true),
+ StandardCharsets.UTF_8)) {
+ writer.write("Neuron number : " + neuron_number
+ + " Learning rate : " + learning_rate
+ + " Algorithm : " + args[3]
+ + " Activation function : " + args[4]
+ + " Average AUC = "
+ + (average_auc / fold_number) + "\n");
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
}
}
}
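
For reference, and not part of the commit, a sketch of the argument layout implied by the new main() parsing above; the MainDL4JExample class is hypothetical and the JSON paths are the ones used by MainTest:

package be.cylab.java.wowa.training;
public final class MainDL4JExample {
private MainDL4JExample() {
}
public static void main(final String[] args) {
// args[0]=first neuron count, args[1]=last neuron count (exclusive),
// args[2]=learning rate, args[3]=optimization algorithm,
// args[4]=activation function, args[5]=data JSON, args[6]=expected JSON.
MainDL4J.main(new String[] {
"5", "50", "0.01",
"CONJUGATE_GRADIENT", "TANH",
"ressources/webshell_data.json",
"ressources/webshell_expected.json",
});
}
}
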
package be.cylab.java.wowa.training;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import java.util.HashMap;
import java.util.List;
/**
* Main class to test and compare the efficiency of nn and wt.
*/
public final class MainTest {
private MainTest() {
}
/**
* @param args
*/
public static void main(final String[] args) {
int population_size = 100;
int crossover_rate = 60;
int mutation_rate = 15;
int max_generation_number = 120;
int selection_method
= TrainerParameters.SELECTION_METHOD_RWS;
int generation_population_method
= TrainerParameters.POPULATION_INITIALIZATION_RANDOM;
String data_file = "ressources/webshell_data.json";
String expected_file = "ressources/webshell_expected.json";
List<List<Double>> data
= Utils.convertJsonToDataForTrainer(data_file);
List<Double> expected
= Utils.convertJsonToExpectedForTrainer(expected_file);
TrainerParameters wt_parameters = new TrainerParameters(
null,
population_size,
crossover_rate,
mutation_rate,
max_generation_number,
selection_method,
generation_population_method);
Trainer trainer = new Trainer(wt_parameters,
new SolutionDistance(5));
int neurons_number = 20;
double learning_rate = 0.01;
OptimizationAlgorithm algo = OptimizationAlgorithm.CONJUGATE_GRADIENT;
Activation activation_function = Activation.TANH;
HyperParameters nn_parameters = new HyperParameters(
neurons_number,
learning_rate,
algo,
activation_function
);
NeuralNetwork nn = new NeuralNetwork(nn_parameters);
TrainingDataset dataset = new TrainingDataset(data, expected);
List<TrainingDataset> folds = dataset.prepareFolds(10);
long start_time = System.currentTimeMillis();
System.out.println("Wowa training");
HashMap<AbstractSolution, Double> map_wt = trainer.runKFold(folds, 10);
long end_time = System.currentTimeMillis();
System.out.println("Execution time : " + (end_time - start_time) / 1000 + " seconds");
start_time = System.currentTimeMillis();
System.out.println("Neural Network learning");
HashMap<MultiLayerNetwork, Double> map_nn = nn.runKFold(folds, 10);
end_time = System.currentTimeMillis();
System.out.println("Execution time : " + (end_time - start_time) / 1000 + " seconds");
double nn_score = 0.0;
double wt_score = 0.0;
for (Double d : map_nn.values()) {
nn_score = nn_score + d;
}
System.out.println("Average AUC for Neural Network learning : "
+ nn_score / 10);
for (Double d : map_wt.values()) {
wt_score = wt_score + d;
}
System.out.println("Average AUC for WOWA learning : " + wt_score / 10);
}
}
@@ -7,7 +7,6 @@ import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.ROC;
- import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
@@ -24,12 +23,14 @@ import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.IOException;
+ import java.util.HashMap;
+ import java.util.List;
/**
* Class to test neural network learning.
*/
- public final class LearningNeuralNetwork {
+ public final class NeuralNetwork {
/**
* Class count.
@@ -41,67 +42,40 @@ public final class LearningNeuralNetwork {
public static final int FEATURES_COUNT = 5;
- private DataSet training_data;
- private DataSet testing_data;
- private double learning_rate;
- private OptimizationAlgorithm algorithm;
- private Activation activation_function;
+ private HyperParameters parameters;
/**
- * Constructor.
+ * Default constructor.
*
- * @param file_name
+ * @param parameters
*/
- public LearningNeuralNetwork(final String file_name,
- final double learning_rate,
- final OptimizationAlgorithm algorithm,
- final Activation activation_function
+ public NeuralNetwork(final HyperParameters parameters
) {
- try (RecordReader record_reader = new CSVRecordReader(0, ',')) {
- record_reader.initialize(new FileSplit(
- new ClassPathResource(file_name).getFile()
- ));
- DataSetIterator iterator = new RecordReaderDataSetIterator(
- record_reader, 12468, FEATURES_COUNT, CLASSES_COUNT);
- DataSet all_data = iterator.next();
- all_data.shuffle();
- DataNormalization normalizer = new NormalizerStandardize();
- normalizer.fit(all_data);
- normalizer.transform(all_data);
- SplitTestAndTrain test_and_train = all_data.splitTestAndTrain(0.90);
- this.training_data = test_and_train.getTrain();
- this.testing_data = test_and_train.getTest();
- } catch (IOException e) {
- e.printStackTrace();
- } catch (InterruptedException e) {
- e.printStackTrace();
- }
- this.learning_rate = learning_rate;
- this.algorithm = algorithm;
- this.activation_function = activation_function;
+ this.parameters = parameters;
}
/**
- * Method to build a simple neural network, train and test it.
- *
- * @param neurons_number
+ * @param learning
+ * @return
*/
- public void learning(final int neurons_number) {
+ MultiLayerNetwork learning(
+ final DataSet learning) {
+ int neurons_number = parameters.getNeuronsNumber();
MultiLayerConfiguration configuration
= new NeuralNetConfiguration.Builder()
.iterations(1000)
- .activation(this.activation_function)
+ .activation(parameters.getActivationFunction())
.optimizationAlgo(
- this.algorithm)
+ parameters.getAlgorithm())
.weightInit(WeightInit.XAVIER)
- .learningRate(this.learning_rate)
+ .learningRate(parameters.getLearningRate())
.regularization(true).l2(0.0001)
.list()
- .layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT)
+ .layer(0, new DenseLayer.Builder()
+ .nIn(learning.getFeatures().columns())
.nOut(neurons_number).build())
.layer(1, new DenseLayer.Builder().nIn(neurons_number)
.nOut(neurons_number).build())
@@ -109,22 +83,196 @@ public final class LearningNeuralNetwork {
.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.nIn(neurons_number)
- .nOut(CLASSES_COUNT).build())
+ .nOut(learning.getLabels().columns()).build())
.backprop(true).pretrain(false)
.build();
MultiLayerNetwork model = new MultiLayerNetwork(configuration);
model.init();
- model.fit(this.training_data);
+ model.fit(learning);
return model;
}
/**
* @param data
* @param expected
* @return
*/
public MultiLayerNetwork run(
final List<List<Double>> data,
final List<Double> expected) {
DataSet all_data = Utils.createDataSet(data, expected);
//DataSet training_data
// = prepareDataSetForTrainingAndTesting(all_data).getTrain();
return this.learning(all_data);
}
/**
* @param data_filename
* @param expected_filename
* @return
*/
public MultiLayerNetwork run(
final String data_filename,
final String expected_filename) {
List<List<Double>> data
= Utils.convertJsonToDataForTrainer(data_filename);
List<Double> expected
= Utils.convertJsonToExpectedForTrainer(expected_filename);
return this.run(data, expected);
}
/**
* @param filename
* @return
*/
public MultiLayerNetwork runCSV(final String filename) {
try (RecordReader record_reader = new CSVRecordReader(0, ',')) {
record_reader.initialize(new FileSplit(
new ClassPathResource(filename).getFile()
));
DataSetIterator iterator = new RecordReaderDataSetIterator(
record_reader, 12468, FEATURES_COUNT, CLASSES_COUNT);
DataSet all_data = iterator.next();
//DataSet training_data
// = prepareDataSetForTrainingAndTesting(all_data).getTrain();
return learning(all_data);
} catch (IOException e) {
e.printStackTrace();
} catch (InterruptedException e) {
e.printStackTrace();
}
return null;
}
/**
* Method to evaluate the performance of the model.
*
* @param testing
* @param network
*/
public Double modelEvaluation(
final DataSet testing,
final MultiLayerNetwork network) {
INDArray output
- = model.output((this.testing_data.getFeatureMatrix()));
+ = network.output((testing.getFeatureMatrix()));
- Evaluation eval = new Evaluation(CLASSES_COUNT);
+ Evaluation eval = new Evaluation(testing.getLabels().columns());
- eval.eval(this.testing_data.getLabels(), output);
+ eval.eval(testing.getLabels(), output);
System.out.println(eval.stats());
- ROC roc = new ROC(CLASSES_COUNT);
+ ROC roc = new ROC(testing.getLabels().columns());
- roc.eval(this.testing_data.getLabels(), output);
+ roc.eval(testing.getLabels(), output);
System.out.println(roc.stats());
return roc.calculateAUC();
}
/**
* @param testing
* @param network
* @return
*/
public Double modelEvaluation(
final TrainingDataset testing,
final MultiLayerNetwork network) {
DataSet test
= Utils.createDataSet(testing.getData(), testing.getExpected());
return modelEvaluation(test, network);
}
/**
* @param data
* @return
*/
SplitTestAndTrain prepareDataSetForTrainingAndTesting(final DataSet data) {
data.shuffle();
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(data);
normalizer.transform(data);
return data.splitTestAndTrain(parameters.getPercentTestTrain());
}
/**
* @param data
* @param expected
* @param fold_number
* @param increase_ratio
* @return
*/
public HashMap<MultiLayerNetwork, Double> runKFold(
final List<List<Double>> data,
final List<Double> expected,
final int fold_number,
final int increase_ratio) {
TrainingDataset dataset = new TrainingDataset(data, expected);
List<TrainingDataset> folds = dataset.prepareFolds(fold_number);
HashMap<MultiLayerNetwork, Double> map = new HashMap<>();
for (int i = 0; i < fold_number; i++) {
TrainingDataset testingfold = folds.get(i);
TrainingDataset learning_fold = new TrainingDataset();
for (int j = 0; j < fold_number; j++) {
if (j != i) {
learning_fold.addFoldInDataset(folds, j);
}
}
TrainingDataset dataset_increased = learning_fold.increaseTrueAlert(
increase_ratio);
MultiLayerNetwork nn = run(
dataset_increased.getData(),
dataset_increased.getExpected());
Double score = modelEvaluation(testingfold, nn);
map.put(nn, score);
}
return map;
}
/**
* @param filename_data
* @param filename_expected
* @param fold_number
* @param increase_ratio
* @return
*/
public HashMap<MultiLayerNetwork, Double> runKFold(
final String filename_data,
final String filename_expected,
final int fold_number,
final int increase_ratio) {
List<List<Double>> data
= Utils.convertJsonToDataForTrainer(filename_data);
List<Double> expected
= Utils.convertJsonToExpectedForTrainer(filename_expected);
return runKFold(data, expected, fold_number, increase_ratio);
}
HashMap<MultiLayerNetwork, Double> runKFold(
final List<TrainingDataset> prepared_folds,
final int increase_ratio) {
HashMap<MultiLayerNetwork, Double> map = new HashMap<>();
for (int i = 0; i < prepared_folds.size(); i++) {
TrainingDataset testingfold = prepared_folds.get(i);
TrainingDataset learning_fold = new TrainingDataset();
for (int j = 0; j < prepared_folds.size(); j++) {
if (j != i) {
learning_fold.addFoldInDataset(prepared_folds, j);
}
}
TrainingDataset dataset_increased = learning_fold.increaseTrueAlert(
increase_ratio);
System.out.println("Fold number : " + (i + 1));
MultiLayerNetwork nn = run(
dataset_increased.getData(),
dataset_increased.getExpected());
Double score = modelEvaluation(testingfold, nn);
map.put(nn, score);
}
return map;
}
}
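
Not part of the commit, a minimal sketch of a single train-and-evaluate round with the new NeuralNetwork API, which is what runKFold does internally for each fold; the NeuralNetworkExample class name and hyper-parameter values are illustrative, and the JSON paths are the ones MainTest uses:

package be.cylab.java.wowa.training;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import java.util.List;
public final class NeuralNetworkExample {
private NeuralNetworkExample() {
}
public static void main(final String[] args) {
HyperParameters parameters = new HyperParameters(
20, 0.01,
OptimizationAlgorithm.CONJUGATE_GRADIENT,
Activation.TANH);
NeuralNetwork nn = new NeuralNetwork(parameters);
List<List<Double>> data
= Utils.convertJsonToDataForTrainer("ressources/webshell_data.json");
List<Double> expected
= Utils.convertJsonToExpectedForTrainer("ressources/webshell_expected.json");
// Hold out fold 0 for testing and train on the nine other folds.
TrainingDataset dataset = new TrainingDataset(data, expected);
List<TrainingDataset> folds = dataset.prepareFolds(10);
TrainingDataset testing = folds.get(0);
TrainingDataset learning = new TrainingDataset();
for (int j = 1; j < folds.size(); j++) {
learning.addFoldInDataset(folds, j);
}
MultiLayerNetwork model = nn.run(
learning.getData(),
learning.getExpected());
Double auc = nn.modelEvaluation(testing, model);
System.out.println("AUC on the held-out fold : " + auc);
}
}
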
@@ -104,7 +104,7 @@ public class Trainer {
final int fold_number,
final int increase_ration_alert) {
TrainingDataset dataset = new TrainingDataset(data, expected);
- List<TrainingDataset> folds = prepareFolds(dataset, fold_number);
+ List<TrainingDataset> folds = dataset.prepareFolds(fold_number);
HashMap<AbstractSolution, Double> map = new HashMap<>();
for (int i = 0; i < fold_number; i++) {
TrainingDataset testing = folds.get(i);
@@ -114,9 +114,7 @@ public class Trainer {
learning.addFoldInDataset(folds, j);
}
}
- TrainingDataset dataset_increased = increaseTrueAlert(
- learning.getData(),
- learning.getExpected(),
+ TrainingDataset dataset_increased = learning.increaseTrueAlert(
increase_ration_alert);
AbstractSolution sol = run(
dataset_increased.getData(),
@@ -160,83 +158,38 @@ public class Trainer {
}
/**
- * Method to increase the number of true alert in data.
- * to increase penalty to do not detect a true alert
- *
- * @param data
- * @param expected
+ * @param prepared_folds
* @param increase_ratio
+ * @return
*/
- final TrainingDataset increaseTrueAlert(
- final List<List<Double>> data,
- final List<Double> expected,
- final int increase_ratio) {
- int data_size = expected.size();
- for (int i = 0; i < data_size; i++) {
- if (expected.get(i) == 1) {
- for (int j = 0; j < increase_ratio - 1; j++) {
- expected.add(expected.get(i));
- data.add(data.get(i));
- }
- }
- }
- return new TrainingDataset(data, expected);
- }
- /**
- * Method to separate randomly the base dataset in X folds dataset.
- *
- * @param dataset
- * @param fold_number
- * @return
- */
- final List<TrainingDataset> prepareFolds(
- final TrainingDataset dataset,
- final int fold_number) {
- //List<List<Double>> data = dataset.getData();
- List<Double> expected = dataset.getExpected();
- List<TrainingDataset> fold_dataset = new ArrayList<>();
- //Check if it is rounded !!!!
- int alert_number
- = (int) Math.floor(Utils.sumListElements(expected)
- / fold_number);
- int no_alert_number = (int) (expected.size()
- - Utils.sumListElements(expected)) / fold_number;
- for (int i = 0; i < fold_number; i++) {
- TrainingDataset tmp = new TrainingDataset();
- int alert_counter = 0;
- int no_alert_counter = 0;
- while (tmp.getLength() < (alert_number + no_alert_number)) {
- int index = Utils.randomInteger(0, dataset.getLength() - 1);
- if (dataset.getExpected().get(index) == 1
- && alert_counter < alert_number) {
- tmp.addElementInDataset(dataset, index);
- dataset.removeElementInDataset(index);
- alert_counter++;
- } else if (dataset.getExpected().get(index) == 0
- && no_alert_counter < no_alert_number) {
- tmp.addElementInDataset(dataset, index);
- dataset.removeElementInDataset(index);
- no_alert_counter++;
- }
- }
- fold_dataset.add(tmp);
- }
- int fold_counter = 0;
- while (dataset.getLength() > 0) {
- int i = Utils.randomInteger(0, dataset.getLength() - 1);
- fold_dataset.get(fold_counter).addElementInDataset(dataset, i);
- dataset.removeElementInDataset(i);
- if (fold_counter == fold_dataset.size() - 1) {
- fold_counter = 0;
- } else {
- fold_counter++;
- }
- }
- return fold_dataset;
- }
+ HashMap<AbstractSolution, Double> runKFold(
+ final List<TrainingDataset> prepared_folds,
+ final int increase_ratio) {
+ HashMap<AbstractSolution, Double> map = new HashMap<>();
+ for (int i = 0; i < prepared_folds.size(); i++) {
+ TrainingDataset testing = prepared_folds.get(i);
+ TrainingDataset learning = new TrainingDataset();
+ for (int j = 0; j < prepared_folds.size(); j++) {
+ if (j != i) {
+ learning.addFoldInDataset(prepared_folds, j);
+ }
+ }
+ TrainingDataset dataset_increased = learning.increaseTrueAlert(
+ increase_ratio);
+ AbstractSolution sol = run(
+ dataset_increased.getData(),
+ dataset_increased.getExpected());
+ Double score = sol.computeAUC(
+ testing.getData(),
+ testing.getExpected());
+ System.out.println("Fold number : " + (i + 1) + " AUC : " + score);
+ map.put(sol, score);
+ }
+ return map;
+ }
/**
* Find the best element in the population based on its fitness score.
*
@@ -390,6 +343,7 @@ public class Trainer {
/**
* Method used only for tests !!
* This method generates random number (tos) by using a seed !
+ *
* @param solutions
* @param selected_elements
* @param count
...
@@ -117,4 +117,76 @@ class TrainingDataset {
this.length += data_to_add.get(index).getLength();
return this;
}
/**
* Method to randomly split the base dataset into X fold datasets.
*
* @param fold_number
* @return
*/
final List<TrainingDataset> prepareFolds(
final int fold_number) {
//List<List<Double>> data = dataset.getData();
List<Double> expected = this.getExpected();
List<TrainingDataset> fold_dataset = new ArrayList<>();
//Check if it is rounded !!!!
int alert_number
= (int) Math.floor(Utils.sumListElements(expected)
/ fold_number);
int no_alert_number = (int) (expected.size()
- Utils.sumListElements(expected)) / fold_number;
for (int i = 0; i < fold_number; i++) {
TrainingDataset tmp = new TrainingDataset();
int alert_counter = 0;
int no_alert_counter = 0;
while (tmp.getLength() < (alert_number + no_alert_number)) {
int index = Utils.randomInteger(0, this.getLength() - 1);
if (this.getExpected().get(index) == 1
&& alert_counter < alert_number) {
tmp.addElementInDataset(this, index);
this.removeElementInDataset(index);
alert_counter++;
} else if (this.getExpected().get(index) == 0
&& no_alert_counter < no_alert_number) {
tmp.addElementInDataset(this, index);
this.removeElementInDataset(index);
no_alert_counter++;
}
}
fold_dataset.add(tmp);
}
int fold_counter = 0;
while (this.getLength() > 0) {
int i = Utils.randomInteger(0, this.getLength() - 1);
fold_dataset.get(fold_counter).addElementInDataset(this, i);
this.removeElementInDataset(i);
if (fold_counter == fold_dataset.size() - 1) {
fold_counter = 0;
} else {
fold_counter++;
}
}
return fold_dataset;
}
/**
* Method to increase the number of true alerts in the data,
* to increase the penalty for not detecting a true alert.
*
* @param increase_ratio
*/
final TrainingDataset increaseTrueAlert(
final int increase_ratio) {
int data_size = expected.size();
for (int i = 0; i < data_size; i++) {
if (expected.get(i) == 1) {
for (int j = 0; j < increase_ratio - 1; j++) {
expected.add(expected.get(i));
data.add(data.get(i));
}
}
}
return new TrainingDataset(data, expected);
}
}
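
To make the oversampling arithmetic concrete (illustrative numbers, not part of the commit): with 100 samples of which 20 are true alerts and increase_ratio = 10, each alert gains 9 extra copies, so the dataset grows to 20 * 10 + 80 = 280 elements while the 80 negatives are untouched. A minimal sketch, with a hypothetical IncreaseTrueAlertExample class:

package be.cylab.java.wowa.training;
import java.util.ArrayList;
import java.util.List;
public final class IncreaseTrueAlertExample {
private IncreaseTrueAlertExample() {
}
public static void main(final String[] args) {
List<List<Double>> data = new ArrayList<>();
List<Double> expected = new ArrayList<>();
// 20 true alerts (label 1.0) followed by 80 normal samples (label 0.0).
for (int i = 0; i < 100; i++) {
List<Double> vector = new ArrayList<>();
for (int j = 0; j < 5; j++) {
vector.add(0.5);
}
data.add(vector);
expected.add(i < 20 ? 1.0 : 0.0);
}
TrainingDataset ds = new TrainingDataset(data, expected)
.increaseTrueAlert(10);
// Expected length: 20 * 10 + 80 = 280.
System.out.println("Length after oversampling : " + ds.getLength());
}
}
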
@@ -4,6 +4,9 @@ import com.owlike.genson.GenericType;
import com.owlike.genson.Genson;
import info.debatty.java.aggregation.WOWA;
import org.apache.commons.lang3.ArrayUtils;
+ import org.nd4j.linalg.api.ndarray.INDArray;
+ import org.nd4j.linalg.dataset.DataSet;
+ import org.nd4j.linalg.factory.Nd4j;
import java.io.FileOutputStream;
import java.io.IOException;
@@ -569,4 +572,31 @@ final class Utils {
e.printStackTrace();
}
}
static DataSet createDataSet(
final List<List<Double>> data,
final List<Double> expected) {
if (data.size() != expected.size()) {
throw new IllegalArgumentException(
"Data and Expected must have the same size");
}
double[][] data_array = new double[data.size()][data.get(0).size()];
double[][] expected_array = new double[expected.size()][2];
for (int i = 0; i < data.size(); i++) {
if (expected.get(i) == 1.0) {
expected_array[i][0] = 1.0;
expected_array[i][1] = 0.0;
} else {
expected_array[i][0] = 0.0;
expected_array[i][1] = 1.0;
}
for (int j = 0; j < data.get(i).size(); j++) {
data_array[i][j] = data.get(i).get(j);
}
}
INDArray data_ind = Nd4j.create(data_array);
INDArray expected_ind = Nd4j.create(expected_array);
return new DataSet(data_ind, expected_ind);
}
}
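
As a reading aid, not part of the commit: createDataSet one-hot encodes the labels, so an expected value of 1.0 becomes the label row [1.0, 0.0] and anything else becomes [0.0, 1.0]. A minimal sketch, assuming it runs from the same package (the CreateDataSetExample class and values are illustrative):

package be.cylab.java.wowa.training;
import java.util.Arrays;
import java.util.List;
import org.nd4j.linalg.dataset.DataSet;
public final class CreateDataSetExample {
private CreateDataSetExample() {
}
public static void main(final String[] args) {
List<List<Double>> data = Arrays.asList(
Arrays.asList(0.1, 0.2, 0.3, 0.4, 0.5),
Arrays.asList(0.9, 0.8, 0.7, 0.6, 0.5));
List<Double> expected = Arrays.asList(1.0, 0.0);
DataSet ds = Utils.createDataSet(data, expected);
// Features keep their 2 x 5 shape, labels become a 2 x 2 one-hot matrix:
// row 0 -> [1.0, 0.0] (true alert), row 1 -> [0.0, 1.0] (no alert).
System.out.println(ds.getFeatures());
System.out.println(ds.getLabels());
}
}
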
@@ -22,66 +22,6 @@ class TrainerTest {
void testRun() {
}
@Test
void testIncreaseTrueAlert() {
List<List<Double>> data = generateData(100, 5);
int original_data_set_size = data.size();
List<Double> expected = generateExpectedBinaryClassification(100);
int increase_ratio = 10;
int number_of_true_alert = (int)(double)Utils.sumListElements(expected);
int number_of_no_alert = original_data_set_size - number_of_true_alert;
TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio);
//Check if the length of the dataset us correct
//increase_ration * number_of_true_alert + expected.size()
assertEquals(number_of_true_alert * increase_ratio + number_of_no_alert, ds.getLength());
//Check the number of true_alert in the dataset
//increase_ration * number_of_true_alert + number_of_true_alert
assertEquals(number_of_true_alert * increase_ratio, (double)Utils.sumListElements(ds.getExpected()));
//Check if each rue alert elements are present (increase_ratio time) in he dataset
//Here, we check if each true alert element is present 10 more times than the original one
for (int i = 0; i < expected.size(); i++) {
if (ds.getExpected().get(i) == 1.0) {
int cnt = 0;
for (int j = 0; j < ds.getLength(); j++) {
if (ds.getExpected().get(i) == ds.getExpected().get(j)) {
cnt++;
}
}
assertEquals(increase_ratio, cnt);
}
}
}
@Test
void testPrepareFolds() {
int number_of_elements = 100;
List<List<Double>> data = generateData(number_of_elements,5);
List<Double> expected = generateExpectedBinaryClassification(number_of_elements);
int increase_ratio = 3;
int fold_number = 10;
int number_of_alert = (int)(double)Utils.sumListElements(expected);
int number_of_no_alert = number_of_elements - number_of_alert;
TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio);
List<TrainingDataset> folds = trainer.prepareFolds(ds, fold_number);
assertEquals(fold_number, folds.size());
for (int i = 0; i < folds.size(); i++) {
assertTrue(folds.get(i).getLength()
== (number_of_alert * increase_ratio + number_of_no_alert)
/ fold_number || folds.get(i).getLength()
== 1 + (number_of_alert * increase_ratio + number_of_no_alert)
/ fold_number);
assertTrue(Utils.sumListElements(folds.get(i).getExpected())
== ((number_of_alert * increase_ratio) / fold_number) ||
Utils.sumListElements(folds.get(i).getExpected())
== 1 + ((number_of_alert * increase_ratio) / fold_number)
|| Utils.sumListElements(folds.get(i).getExpected())
== 2 + ((number_of_alert * increase_ratio) / fold_number));
}
}
@Test
void testFindBestSolution() {
...
@@ -6,6 +6,7 @@ import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
+ import java.util.Random;
import static org.junit.jupiter.api.Assertions.*;
@@ -108,4 +109,100 @@ class TrainingDatasetTest {
}
}
}
@Test
void testIncreaseTrueAlert() {
List<List<Double>> data = generateData(100, 5);
int original_data_set_size = data.size();
List<Double> expected = generateExpectedBinaryClassification(100);
int increase_ratio = 10;
int number_of_true_alert = (int)(double)Utils.sumListElements(expected);
int number_of_no_alert = original_data_set_size - number_of_true_alert;
TrainingDataset ds = new TrainingDataset(data, expected).increaseTrueAlert(increase_ratio);
//Check if the length of the dataset is correct
//increase_ratio * number_of_true_alert + number_of_no_alert
assertEquals(number_of_true_alert * increase_ratio + number_of_no_alert, ds.getLength());
//Check the number of true_alert in the dataset
//increase_ratio * number_of_true_alert
assertEquals(number_of_true_alert * increase_ratio, (double)Utils.sumListElements(ds.getExpected()));
//Check that each true alert element is present (increase_ratio times) in the dataset
//Here, we check that each true alert element appears increase_ratio times in total
for (int i = 0; i < expected.size(); i++) {
if (ds.getExpected().get(i) == 1.0) {
int cnt = 0;
for (int j = 0; j < ds.getLength(); j++) {
if (ds.getExpected().get(i) == ds.getExpected().get(j)) {
cnt++;
}
}
assertEquals(increase_ratio, cnt);
}
}
}
@Test
void testPrepareFolds() {
int number_of_elements = 100;
List<List<Double>> data = generateData(number_of_elements,5);
List<Double> expected = generateExpectedBinaryClassification(number_of_elements);
int increase_ratio = 3;
int fold_number = 10;
int number_of_alert = (int)(double)Utils.sumListElements(expected);
int number_of_no_alert = number_of_elements - number_of_alert;
TrainingDataset ds = new TrainingDataset(data, expected).increaseTrueAlert(increase_ratio);
List<TrainingDataset> folds = ds.prepareFolds(fold_number);
assertEquals(fold_number, folds.size());
for (int i = 0; i < folds.size(); i++) {
assertTrue(folds.get(i).getLength()
== (number_of_alert * increase_ratio + number_of_no_alert)
/ fold_number || folds.get(i).getLength()
== 1 + (number_of_alert * increase_ratio + number_of_no_alert)
/ fold_number);
assertTrue(Utils.sumListElements(folds.get(i).getExpected())
== ((number_of_alert * increase_ratio) / fold_number) ||
Utils.sumListElements(folds.get(i).getExpected())
== 1 + ((number_of_alert * increase_ratio) / fold_number)
|| Utils.sumListElements(folds.get(i).getExpected())
== 2 + ((number_of_alert * increase_ratio) / fold_number));
}
}
static List<List<Double>> generateData(final int size, final int weight_number) {
Random rnd = new Random(5489);
List<List<Double>> data = new ArrayList<>();
for (int i = 0; i < size; i++) {
List<Double> vector = new ArrayList<>();
for (int j = 0; j < weight_number; j++) {
vector.add(rnd.nextDouble());
}
data.add(vector);
}
return data;
}
static List<Double> generateExpected(final int size) {
Random rnd = new Random(5768);
List<Double> expected = new ArrayList<>();
for (int i = 0; i < size; i++) {
expected.add(rnd.nextDouble());
}
return expected;
}
static List<Double> generateExpectedBinaryClassification(final int size) {
Random rnd = new Random(5768);
List<Double> expected = new ArrayList<>();
for (int i = 0; i < size; i++) {
if (rnd.nextDouble() <= 0.5) {
expected.add(new Double(0.0));
} else {
expected.add(new Double(1.0));
}
}
return expected;
}
}
\ No newline at end of file