diff --git a/java-wowa-training.iml b/java-wowa-training.iml
index 08c09f61200f2f2388e08bc4ac406e511612c382..e65012bc30b00fe8f9ce599d0af7a8d7f858c371 100644
--- a/java-wowa-training.iml
+++ b/java-wowa-training.iml
@@ -42,13 +42,15 @@
     <orderEntry type="library" name="Maven: info.debatty:java-aggregation:0.5" level="project" />
     <orderEntry type="library" name="Maven: com.owlike:genson:1.6" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-lang3:3.8.1" level="project" />
-    <orderEntry type="library" name="Maven: be.cylab:java-roc:0.0.3" level="project" />
+    <orderEntry type="library" name="Maven: be.cylab:java-roc:0.0.4" level="project" />
     <orderEntry type="library" name="Maven: com.opencsv:opencsv:4.5" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-text:1.3" level="project" />
     <orderEntry type="library" name="Maven: commons-beanutils:commons-beanutils:1.9.3" level="project" />
     <orderEntry type="library" name="Maven: commons-logging:commons-logging:1.2" level="project" />
     <orderEntry type="library" name="Maven: commons-collections:commons-collections:3.2.2" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-collections4:4.2" level="project" />
+    <orderEntry type="library" name="Maven: org.knowm.xchart:xchart:3.5.4" level="project" />
+    <orderEntry type="library" name="Maven: de.erichseifert.vectorgraphics2d:VectorGraphics2D:0.13" level="project" />
     <orderEntry type="library" name="Maven: org.datavec:datavec-api:0.9.1" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-compress:1.8.1" level="project" />
     <orderEntry type="library" name="Maven: org.apache.commons:commons-math3:3.3" level="project" />
diff --git a/src/main/java/be/cylab/java/wowa/training/HyperParameters.java b/src/main/java/be/cylab/java/wowa/training/HyperParameters.java
new file mode 100644
index 0000000000000000000000000000000000000000..9915e6e3f926f7bb17b4cfdb4290aefbf5c0be8a
--- /dev/null
+++ b/src/main/java/be/cylab/java/wowa/training/HyperParameters.java
@@ -0,0 +1,140 @@
+package be.cylab.java.wowa.training;
+
+import org.deeplearning4j.nn.api.OptimizationAlgorithm;
+import org.nd4j.linalg.activations.Activation;
+
+/**
+ * Parameters of the NeuralNetwork class.
+ */
+public class HyperParameters {
+    private int neurons_number;
+    private double learning_rate;
+    private OptimizationAlgorithm algorithm;
+    private Activation activation_function;
+    private double percent_test_train;
+
+    /**
+     * Main constructor.
+     * @param neurons_number
+     * @param learning_rate
+     * @param algorithm
+     * @param activation_function
+     * @param percent_test_train
+     */
+    public HyperParameters(
+            final int neurons_number,
+            final double learning_rate,
+            final OptimizationAlgorithm algorithm,
+            final Activation activation_function,
+            final double percent_test_train) {
+        setNeuronsNumber(neurons_number);
+        setLearningRate(learning_rate);
+        setAlgorithm(algorithm);
+        setActivationFunction(activation_function);
+        setPercentTestTrain(percent_test_train);
+    }
+
+    /**
+     * Constructor without percent_test_train value (defaults to 99).
+     * @param neurons_number
+     * @param learning_rate
+     * @param algorithm
+     * @param activation_function
+     */
+    public HyperParameters(
+            final int neurons_number,
+            final double learning_rate,
+            final OptimizationAlgorithm algorithm,
+            final Activation activation_function) {
+        this(neurons_number, learning_rate, algorithm,
+                activation_function, 99);
+    }
+
+    /**
+     * Getter for neurons_number.
+ * @return + */ + public final int getNeuronsNumber() { + return neurons_number; + } + + /** + * Getter for learning rate. + * @return + */ + public final double getLearningRate() { + return learning_rate; + } + + /** + * Getter for backpropagation algorithm. + * @return + */ + public final OptimizationAlgorithm getAlgorithm() { + return algorithm; + } + + /** + * Getter for activation function. + * @return + */ + public final Activation getActivationFunction() { + return activation_function; + } + + /** + * Getter for percent_test_train. + * @return + */ + public final double getPercentTestTrain() { + return percent_test_train; + } + + /** + * @param neurons_number + */ + public final void setNeuronsNumber(final int neurons_number) { + if (neurons_number < 5 || neurons_number > 80) { + throw new IllegalArgumentException( + "Neuron number must be between 5 and 80"); + } + this.neurons_number = neurons_number; + } + + /** + * @param learning_rate + */ + public final void setLearningRate(final double learning_rate) { + if (learning_rate <= 0.0 || learning_rate >= 1.0) { + throw new IllegalArgumentException( + "Learning rate must be between 0 and 1"); + } + this.learning_rate = learning_rate; + } + + /** + * @param algorithm + */ + public final void setAlgorithm(final OptimizationAlgorithm algorithm) { + this.algorithm = algorithm; + } + + /** + * @param activation_function + */ + public final void setActivationFunction( + final Activation activation_function) { + this.activation_function = activation_function; + } + + /** + * @param percent_test_train + */ + public final void setPercentTestTrain(final double percent_test_train) { + if (percent_test_train <= 0 || percent_test_train >= 100) { + throw new IllegalArgumentException( + "Percentage of train must be between 0 and 100"); + } + this.percent_test_train = percent_test_train; + } +} diff --git a/src/main/java/be/cylab/java/wowa/training/LearningNeuralNetwork.java b/src/main/java/be/cylab/java/wowa/training/LearningNeuralNetwork.java deleted file mode 100644 index 7328913d5b67cd2845395633e1ffff9b547b04eb..0000000000000000000000000000000000000000 --- a/src/main/java/be/cylab/java/wowa/training/LearningNeuralNetwork.java +++ /dev/null @@ -1,130 +0,0 @@ -package be.cylab.java.wowa.training; - -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.records.reader.impl.csv.CSVRecordReader; -import org.datavec.api.split.FileSplit; -import org.datavec.api.util.ClassPathResource; -import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; -import org.deeplearning4j.eval.Evaluation; -import org.deeplearning4j.eval.ROC; -import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.DenseLayer; -import org.deeplearning4j.nn.conf.layers.OutputLayer; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.nn.weights.WeightInit; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.DataSet; -import org.nd4j.linalg.dataset.SplitTestAndTrain; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; -import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; -import org.nd4j.linalg.lossfunctions.LossFunctions; - -import java.io.IOException; - -/** - * Class to test neural network 
learning. - */ - -public final class LearningNeuralNetwork { - - /** - * Class count. - */ - public static final int CLASSES_COUNT = 2; - /** - * Features count. - */ - public static final int FEATURES_COUNT = 5; - - - private DataSet training_data; - private DataSet testing_data; - private double learning_rate; - private OptimizationAlgorithm algorithm; - private Activation activation_function; - - - /** - * Constructor. - * - * @param file_name - */ - public LearningNeuralNetwork(final String file_name, - final double learning_rate, - final OptimizationAlgorithm algorithm, - final Activation activation_function - ) { - try (RecordReader record_reader = new CSVRecordReader(0, ',')) { - record_reader.initialize(new FileSplit( - new ClassPathResource(file_name).getFile() - )); - DataSetIterator iterator = new RecordReaderDataSetIterator( - record_reader, 12468, FEATURES_COUNT, CLASSES_COUNT); - DataSet all_data = iterator.next(); - all_data.shuffle(); - DataNormalization normalizer = new NormalizerStandardize(); - normalizer.fit(all_data); - normalizer.transform(all_data); - - SplitTestAndTrain test_and_train = all_data.splitTestAndTrain(0.90); - this.training_data = test_and_train.getTrain(); - this.testing_data = test_and_train.getTest(); - - } catch (IOException e) { - e.printStackTrace(); - } catch (InterruptedException e) { - e.printStackTrace(); - } - this.learning_rate = learning_rate; - this.algorithm = algorithm; - this.activation_function = activation_function; - } - - /** - * Method to build a simple neural network, train and test it. - * - * @param neurons_number - */ - public void learning(final int neurons_number) { - - MultiLayerConfiguration configuration - = new NeuralNetConfiguration.Builder() - .iterations(1000) - .activation(this.activation_function) - .optimizationAlgo( - this.algorithm) - .weightInit(WeightInit.XAVIER) - .learningRate(this.learning_rate) - .regularization(true).l2(0.0001) - .list() - .layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT) - .nOut(neurons_number).build()) - .layer(1, new DenseLayer.Builder().nIn(neurons_number) - .nOut(neurons_number).build()) - .layer(2, new OutputLayer.Builder(LossFunctions - .LossFunction.NEGATIVELOGLIKELIHOOD) - .activation(Activation.SOFTMAX) - .nIn(neurons_number) - .nOut(CLASSES_COUNT).build()) - .backprop(true).pretrain(false) - .build(); - - MultiLayerNetwork model = new MultiLayerNetwork(configuration); - model.init(); - model.fit(this.training_data); - - INDArray output - = model.output((this.testing_data.getFeatureMatrix())); - - Evaluation eval = new Evaluation(CLASSES_COUNT); - eval.eval(this.testing_data.getLabels(), output); - System.out.println(eval.stats()); - ROC roc = new ROC(CLASSES_COUNT); - roc.eval(this.testing_data.getLabels(), output); - System.out.println(roc.stats()); - } -} diff --git a/src/main/java/be/cylab/java/wowa/training/MainDL4J.java b/src/main/java/be/cylab/java/wowa/training/MainDL4J.java index d8a5b18a6ac2110b44ccec1bc938a8e6f14e6400..f9f0fa0bcc4b71699b15e8b73d87f1cfcd4e313a 100644 --- a/src/main/java/be/cylab/java/wowa/training/MainDL4J.java +++ b/src/main/java/be/cylab/java/wowa/training/MainDL4J.java @@ -1,8 +1,15 @@ package be.cylab.java.wowa.training; import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.linalg.activations.Activation; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; 
+ /** * Class for learn in neuronal network. */ @@ -20,20 +27,24 @@ public final class MainDL4J { * @param args */ public static void main(final String[] args) { - - String file_name = args[0]; - double learning_rate = Double.parseDouble(args[1]); + int begin_neuron_number = Integer.parseInt(args[0]); + int neuron_number_end = Integer.parseInt(args[1]); + String data_file = args[5]; + String expected_file = args[6]; + double learning_rate = Double.parseDouble(args[2]); OptimizationAlgorithm optimization_algorithm; Activation activation_function; + int fold_number = 10; + int increase_ratio = 10; - if (args[2].matches("CONJUGATE_GRADIENT")) { + if (args[3].matches("CONJUGATE_GRADIENT")) { optimization_algorithm = OptimizationAlgorithm.CONJUGATE_GRADIENT; - } else if (args[2].matches("LBFGS")) { + } else if (args[3].matches("LBFGS")) { optimization_algorithm = OptimizationAlgorithm.LBFGS; - } else if (args[2].matches("LINE_GRADIENT_DESCENT")) { + } else if (args[3].matches("LINE_GRADIENT_DESCENT")) { optimization_algorithm = OptimizationAlgorithm.LINE_GRADIENT_DESCENT; - } else if (args[2].matches("STOCHASTIC_GRADIENT_DESCENT")) { + } else if (args[3].matches("STOCHASTIC_GRADIENT_DESCENT")) { optimization_algorithm = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT; } else { @@ -41,27 +52,48 @@ public final class MainDL4J { "Not correct optimization algorithm"); } - if (args[3].matches("RELU")) { + if (args[4].matches("RELU")) { activation_function = Activation.RELU; - } else if (args[3].matches("SIGMOID")) { + } else if (args[4].matches("SIGMOID")) { activation_function = Activation.SIGMOID; - } else if (args[3].matches("TANH")) { + } else if (args[4].matches("TANH")) { activation_function = Activation.TANH; } else { throw new IllegalArgumentException( "Not correct activation function"); } - LearningNeuralNetwork lnn - = new LearningNeuralNetwork(file_name, - learning_rate, - optimization_algorithm, - activation_function); - for (int neurons_number = 5; neurons_number < 50; neurons_number++) { - System.out.println("Number of neurons in hidden layer : " - + neurons_number); - lnn.learning(neurons_number); + for (int neuron_number = begin_neuron_number; + neuron_number < neuron_number_end; neuron_number++) { + double average_auc = 0; + System.out.println("Neuron number : " + neuron_number); + HyperParameters parameters + = new HyperParameters(neuron_number, learning_rate, + optimization_algorithm, activation_function); + NeuralNetwork nn = new NeuralNetwork(parameters); + + HashMap<MultiLayerNetwork, Double> map = nn.runKFold( + data_file, + expected_file, + fold_number, + increase_ratio); + for (Double d : map.values()) { + average_auc = average_auc + d; + } + try (OutputStreamWriter writer = new OutputStreamWriter( + new FileOutputStream("Synthesis_average_AUC.txt", true), + StandardCharsets.UTF_8)) { + writer.write("Neuron number : " + neuron_number + + " Learning rate : " + learning_rate + + " Algorithm : " + args[3] + + " Activation function : " + args[4] + + " Average AUC = " + + (average_auc / fold_number) + "\n"); + } catch (IOException e) { + e.printStackTrace(); + } } + } } diff --git a/src/main/java/be/cylab/java/wowa/training/MainTest.java b/src/main/java/be/cylab/java/wowa/training/MainTest.java new file mode 100644 index 0000000000000000000000000000000000000000..8a2f482f8d8191f2204adb06d67c54d7050597b5 --- /dev/null +++ b/src/main/java/be/cylab/java/wowa/training/MainTest.java @@ -0,0 +1,91 @@ +package be.cylab.java.wowa.training; + +import 
org.deeplearning4j.nn.api.OptimizationAlgorithm;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.nd4j.linalg.activations.Activation;
+
+import java.util.HashMap;
+import java.util.List;
+
+/**
+ * Main class to test and compare the efficiency of nn and wt learning.
+ */
+public final class MainTest {
+
+    private MainTest() {
+
+    }
+
+    /**
+     * @param args
+     */
+    public static void main(final String[] args) {
+        int population_size = 100;
+        int crossover_rate = 60;
+        int mutation_rate = 15;
+        int max_generation_number = 120;
+        int selection_method
+                = TrainerParameters.SELECTION_METHOD_RWS;
+        int generation_population_method
+                = TrainerParameters.POPULATION_INITIALIZATION_RANDOM;
+        String data_file = "ressources/webshell_data.json";
+        String expected_file = "ressources/webshell_expected.json";
+        List<List<Double>> data
+                = Utils.convertJsonToDataForTrainer(data_file);
+        List<Double> expected
+                = Utils.convertJsonToExpectedForTrainer(expected_file);
+
+        TrainerParameters wt_parameters = new TrainerParameters(
+                null,
+                population_size,
+                crossover_rate,
+                mutation_rate,
+                max_generation_number,
+                selection_method,
+                generation_population_method);
+        Trainer trainer = new Trainer(wt_parameters,
+                new SolutionDistance(5));
+
+        int neurons_number = 20;
+        double learning_rate = 0.01;
+        OptimizationAlgorithm algo = OptimizationAlgorithm.CONJUGATE_GRADIENT;
+        Activation activation_function = Activation.TANH;
+        HyperParameters nn_parameters = new HyperParameters(
+                neurons_number,
+                learning_rate,
+                algo,
+                activation_function
+        );
+        NeuralNetwork nn = new NeuralNetwork(nn_parameters);
+
+        TrainingDataset dataset = new TrainingDataset(data, expected);
+        List<TrainingDataset> folds = dataset.prepareFolds(10);
+
+        long start_time = System.currentTimeMillis();
+        System.out.println("Wowa training");
+        HashMap<AbstractSolution, Double> map_wt = trainer.runKFold(folds, 10);
+        long end_time = System.currentTimeMillis();
+        System.out.println("Execution time : " + (end_time - start_time) / 1000 + " seconds");
+
+        start_time = System.currentTimeMillis();
+        System.out.println("Neural Network learning");
+        HashMap<MultiLayerNetwork, Double> map_nn = nn.runKFold(folds, 10);
+        end_time = System.currentTimeMillis();
+        System.out.println("Execution time : " + (end_time - start_time) / 1000 + " seconds");
+
+        double nn_score = 0.0;
+        double wt_score = 0.0;
+        for (Double d : map_nn.values()) {
+            nn_score = nn_score + d;
+        }
+
+        System.out.println("Average AUC for Neural Network learning : "
+                + nn_score / 10);
+
+        for (Double d : map_wt.values()) {
+            wt_score = wt_score + d;
+        }
+
+        System.out.println("Average AUC for WOWA learning : " + wt_score / 10);
+    }
+}
diff --git a/src/main/java/be/cylab/java/wowa/training/NeuralNetwork.java b/src/main/java/be/cylab/java/wowa/training/NeuralNetwork.java
new file mode 100644
index 0000000000000000000000000000000000000000..fc3d5135930e9a771aa157527d74deb356283ee2
--- /dev/null
+++ b/src/main/java/be/cylab/java/wowa/training/NeuralNetwork.java
@@ -0,0 +1,278 @@
+package be.cylab.java.wowa.training;
+
+import org.datavec.api.records.reader.RecordReader;
+import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
+import org.datavec.api.split.FileSplit;
+import org.datavec.api.util.ClassPathResource;
+import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
+import org.deeplearning4j.eval.Evaluation;
+import org.deeplearning4j.eval.ROC;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.dataset.SplitTestAndTrain;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
+import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+
+/**
+ * Class for neural network learning and evaluation.
+ */
+
+public final class NeuralNetwork {
+
+    /**
+     * Class count.
+     */
+    public static final int CLASSES_COUNT = 2;
+    /**
+     * Features count.
+     */
+    public static final int FEATURES_COUNT = 5;
+
+
+    private HyperParameters parameters;
+
+
+    /**
+     * Constructor.
+     *
+     * @param parameters
+     */
+    public NeuralNetwork(final HyperParameters parameters
+    ) {
+
+        this.parameters = parameters;
+    }
+
+    /**
+     * @param training_set
+     * @return
+     */
+    MultiLayerNetwork learning(
+            final DataSet training_set) {
+
+        int neurons_number = parameters.getNeuronsNumber();
+        MultiLayerConfiguration configuration
+                = new NeuralNetConfiguration.Builder()
+                .iterations(1000)
+                .activation(parameters.getActivationFunction())
+                .optimizationAlgo(
+                        parameters.getAlgorithm())
+                .weightInit(WeightInit.XAVIER)
+                .learningRate(parameters.getLearningRate())
+                .regularization(true).l2(0.0001)
+                .list()
+                .layer(0, new DenseLayer.Builder()
+                        .nIn(training_set.getFeatures().columns())
+                        .nOut(neurons_number).build())
+                .layer(1, new DenseLayer.Builder().nIn(neurons_number)
+                        .nOut(neurons_number).build())
+                .layer(2, new OutputLayer.Builder(LossFunctions
+                        .LossFunction.NEGATIVELOGLIKELIHOOD)
+                        .activation(Activation.SOFTMAX)
+                        .nIn(neurons_number)
+                        .nOut(training_set.getLabels().columns()).build())
+                .backprop(true).pretrain(false)
+                .build();
+
+        MultiLayerNetwork model = new MultiLayerNetwork(configuration);
+        model.init();
+        model.fit(training_set);
+
+        return model;
+    }
+
+    /**
+     * @param data
+     * @param expected
+     * @return
+     */
+    public MultiLayerNetwork run(
+            final List<List<Double>> data,
+            final List<Double> expected) {
+
+        DataSet all_data = Utils.createDataSet(data, expected);
+        //DataSet training_data
+        //        = prepareDataSetForTrainingAndTesting(all_data).getTrain();
+        return this.learning(all_data);
+
+    }
+
+    /**
+     * @param data_filename
+     * @param expected_filename
+     * @return
+     */
+    public MultiLayerNetwork run(
+            final String data_filename,
+            final String expected_filename) {
+        List<List<Double>> data
+                = Utils.convertJsonToDataForTrainer(data_filename);
+        List<Double> expected
+                = Utils.convertJsonToExpectedForTrainer(expected_filename);
+        return this.run(data, expected);
+    }
+
+    /**
+     * @param filename
+     * @return
+     */
+    public MultiLayerNetwork runCSV(final String filename) {
+        try (RecordReader record_reader = new CSVRecordReader(0, ',')) {
+            record_reader.initialize(new FileSplit(
+                    new ClassPathResource(filename).getFile()
+            ));
+            DataSetIterator iterator = new RecordReaderDataSetIterator(
+                    record_reader, 12468, FEATURES_COUNT, CLASSES_COUNT);
+            DataSet all_data = iterator.next();
+            //DataSet training_data
+            //        = prepareDataSetForTrainingAndTesting(all_data).getTrain();
+            return learning(all_data);
+
+        } catch (IOException e) {
+            e.printStackTrace();
+        } catch (InterruptedException e) {
+            e.printStackTrace();
+        }
+        return null;
+    }
+
+    /**
+     * Method to evaluate the performance of the model.
+     *
+     * @param testing
+     * @param network
+     */
+    public Double modelEvaluation(
+            final DataSet testing,
+            final MultiLayerNetwork network) {
+        INDArray output
+                = network.output((testing.getFeatureMatrix()));
+
+        Evaluation eval = new Evaluation(testing.getLabels().columns());
+        eval.eval(testing.getLabels(), output);
+        System.out.println(eval.stats());
+        ROC roc = new ROC(testing.getLabels().columns());
+        roc.eval(testing.getLabels(), output);
+        System.out.println(roc.stats());
+        return roc.calculateAUC();
+    }
+
+    /**
+     * @param testing
+     * @param network
+     * @return
+     */
+    public Double modelEvaluation(
+            final TrainingDataset testing,
+            final MultiLayerNetwork network) {
+        DataSet test
+                = Utils.createDataSet(testing.getData(), testing.getExpected());
+        return modelEvaluation(test, network);
+    }
+
+    /**
+     * @param data
+     * @return train/test split (percent_test_train is a 0-100 percentage)
+     */
+    SplitTestAndTrain prepareDataSetForTrainingAndTesting(final DataSet data) {
+        data.shuffle();
+        DataNormalization normalizer = new NormalizerStandardize();
+        normalizer.fit(data);
+        normalizer.transform(data);
+
+        return data.splitTestAndTrain(parameters.getPercentTestTrain() / 100.0);
+    }
+
+    /**
+     * @param data
+     * @param expected
+     * @param fold_number
+     * @param increase_ratio
+     * @return
+     */
+    public HashMap<MultiLayerNetwork, Double> runKFold(
+            final List<List<Double>> data,
+            final List<Double> expected,
+            final int fold_number,
+            final int increase_ratio) {
+        TrainingDataset dataset = new TrainingDataset(data, expected);
+        List<TrainingDataset> folds = dataset.prepareFolds(fold_number);
+        HashMap<MultiLayerNetwork, Double> map = new HashMap<>();
+        for (int i = 0; i < fold_number; i++) {
+            TrainingDataset testingfold = folds.get(i);
+            TrainingDataset learning_fold = new TrainingDataset();
+            for (int j = 0; j < fold_number; j++) {
+                if (j != i) {
+                    learning_fold.addFoldInDataset(folds, j);
+                }
+            }
+            TrainingDataset dataset_increased = learning_fold.increaseTrueAlert(
+                    increase_ratio);
+            MultiLayerNetwork nn = run(
+                    dataset_increased.getData(),
+                    dataset_increased.getExpected());
+            Double score = modelEvaluation(testingfold, nn);
+
+            map.put(nn, score);
+        }
+        return map;
+    }
+
+    /**
+     * @param filename_data
+     * @param filename_expected
+     * @param fold_number
+     * @param increase_ratio
+     * @return
+     */
+    public HashMap<MultiLayerNetwork, Double> runKFold(
+            final String filename_data,
+            final String filename_expected,
+            final int fold_number,
+            final int increase_ratio) {
+        List<List<Double>> data
+                = Utils.convertJsonToDataForTrainer(filename_data);
+        List<Double> expected
+                = Utils.convertJsonToExpectedForTrainer(filename_expected);
+
+        return runKFold(data, expected, fold_number, increase_ratio);
+    }
+
+    HashMap<MultiLayerNetwork, Double> runKFold(
+            final List<TrainingDataset> prepared_folds,
+            final int increase_ratio) {
+        HashMap<MultiLayerNetwork, Double> map = new HashMap<>();
+        for (int i = 0; i < prepared_folds.size(); i++) {
+            TrainingDataset testingfold = prepared_folds.get(i);
+            TrainingDataset learning_fold = new TrainingDataset();
+            for (int j = 0; j < prepared_folds.size(); j++) {
+                if (j != i) {
+                    learning_fold.addFoldInDataset(prepared_folds, j);
+                }
+            }
+            TrainingDataset dataset_increased = learning_fold.increaseTrueAlert(
+                    increase_ratio);
+            System.out.println("Fold number : " + (i + 1));
+            MultiLayerNetwork nn = run(
dataset_increased.getData(), + dataset_increased.getExpected()); + Double score = modelEvaluation(testingfold, nn); + map.put(nn, score); + } + return map; + } + +} diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java index f42121513011caf37d326314aa884fe5de399301..fda997a226203456ae52646316344c90eed31f35 100644 --- a/src/main/java/be/cylab/java/wowa/training/Trainer.java +++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java @@ -104,7 +104,7 @@ public class Trainer { final int fold_number, final int increase_ration_alert) { TrainingDataset dataset = new TrainingDataset(data, expected); - List<TrainingDataset> folds = prepareFolds(dataset, fold_number); + List<TrainingDataset> folds = dataset.prepareFolds(fold_number); HashMap<AbstractSolution, Double> map = new HashMap<>(); for (int i = 0; i < fold_number; i++) { TrainingDataset testing = folds.get(i); @@ -114,9 +114,7 @@ public class Trainer { learning.addFoldInDataset(folds, j); } } - TrainingDataset dataset_increased = increaseTrueAlert( - learning.getData(), - learning.getExpected(), + TrainingDataset dataset_increased = learning.increaseTrueAlert( increase_ration_alert); AbstractSolution sol = run( dataset_increased.getData(), @@ -160,83 +158,38 @@ public class Trainer { } /** - * Method to increase the number of true alert in data. - * to increase penalty to do not detect a true alert - * - * @param data - * @param expected + * @param prepared_folds * @param increase_ratio + * @return */ - final TrainingDataset increaseTrueAlert( - final List<List<Double>> data, - final List<Double> expected, + HashMap<AbstractSolution, Double> runKFold( + final List<TrainingDataset> prepared_folds, final int increase_ratio) { - int data_size = expected.size(); - for (int i = 0; i < data_size; i++) { - if (expected.get(i) == 1) { - for (int j = 0; j < increase_ratio - 1; j++) { - expected.add(expected.get(i)); - data.add(data.get(i)); + HashMap<AbstractSolution, Double> map = new HashMap<>(); + for (int i = 0; i < prepared_folds.size(); i++) { + TrainingDataset testing = prepared_folds.get(i); + TrainingDataset learning = new TrainingDataset(); + for (int j = 0; j < prepared_folds.size(); j++) { + if (j != i) { + learning.addFoldInDataset(prepared_folds, j); } } - } - return new TrainingDataset(data, expected); - } - - /** - * Method to separate randomly the base dataset in X folds dataset. - * - * @param dataset - * @param fold_number - * @return - */ - final List<TrainingDataset> prepareFolds( - final TrainingDataset dataset, - final int fold_number) { - //List<List<Double>> data = dataset.getData(); - List<Double> expected = dataset.getExpected(); - List<TrainingDataset> fold_dataset = new ArrayList<>(); - //Check if it is rounded !!!! 
-            int alert_number
-                    = (int) Math.floor(Utils.sumListElements(expected)
-                    / fold_number);
-            int no_alert_number = (int) (expected.size()
-                    - Utils.sumListElements(expected)) / fold_number;
+            TrainingDataset dataset_increased = learning.increaseTrueAlert(
+                    increase_ratio);
+            AbstractSolution sol = run(
+                    dataset_increased.getData(),
+                    dataset_increased.getExpected());
+            Double score = sol.computeAUC(
+                    testing.getData(),
+                    testing.getExpected());
+            System.out.println("Fold number : " + (i + 1) + " AUC : " + score);
-        for (int i = 0; i < fold_number; i++) {
-            TrainingDataset tmp = new TrainingDataset();
-            int alert_counter = 0;
-            int no_alert_counter = 0;
-            while (tmp.getLength() < (alert_number + no_alert_number)) {
-                int index = Utils.randomInteger(0, dataset.getLength() - 1);
-                if (dataset.getExpected().get(index) == 1
-                        && alert_counter < alert_number) {
-                    tmp.addElementInDataset(dataset, index);
-                    dataset.removeElementInDataset(index);
-                    alert_counter++;
-                } else if (dataset.getExpected().get(index) == 0
-                        && no_alert_counter < no_alert_number) {
-                    tmp.addElementInDataset(dataset, index);
-                    dataset.removeElementInDataset(index);
-                    no_alert_counter++;
-                }
-            }
-            fold_dataset.add(tmp);
-        }
-        int fold_counter = 0;
-        while (dataset.getLength() > 0) {
-            int i = Utils.randomInteger(0, dataset.getLength() - 1);
-            fold_dataset.get(fold_counter).addElementInDataset(dataset, i);
-            dataset.removeElementInDataset(i);
-            if (fold_counter == fold_dataset.size() - 1) {
-                fold_counter = 0;
-            } else {
-                fold_counter++;
-            }
+            map.put(sol, score);
         }
-        return fold_dataset;
+        return map;
     }
+
     /**
      * Find the best element in the population based on its fitness score.
      *
@@ -390,6 +343,7 @@ public class Trainer {
     /**
      * Method used only for tests !!
      * This method generates random number (tos) by using a seed !
+     *
     * @param solutions
     * @param selected_elements
     * @param count
diff --git a/src/main/java/be/cylab/java/wowa/training/TrainingDataset.java b/src/main/java/be/cylab/java/wowa/training/TrainingDataset.java
index 2c75c85e1acbdf0ed1c9d5f888cbe615971e6d46..6d7ff6bb826666d911886309d534f36f32acc4b5 100644
--- a/src/main/java/be/cylab/java/wowa/training/TrainingDataset.java
+++ b/src/main/java/be/cylab/java/wowa/training/TrainingDataset.java
@@ -117,4 +117,76 @@ class TrainingDataset {
         this.length += data_to_add.get(index).getLength();
         return this;
     }
+
+    /**
+     * Method to randomly split the base dataset into fold_number folds.
+     *
+     * @param fold_number
+     * @return
+     */
+    final List<TrainingDataset> prepareFolds(
+            final int fold_number) {
+        //List<List<Double>> data = dataset.getData();
+        List<Double> expected = this.getExpected();
+        List<TrainingDataset> fold_dataset = new ArrayList<>();
+        //Check if it is rounded !!!!
+        int alert_number
+                = (int) Math.floor(Utils.sumListElements(expected)
+                / fold_number);
+        int no_alert_number = (int) (expected.size()
+                - Utils.sumListElements(expected)) / fold_number;
+
+        for (int i = 0; i < fold_number; i++) {
+            TrainingDataset tmp = new TrainingDataset();
+            int alert_counter = 0;
+            int no_alert_counter = 0;
+            while (tmp.getLength() < (alert_number + no_alert_number)) {
+                int index = Utils.randomInteger(0, this.getLength() - 1);
+                if (this.getExpected().get(index) == 1
+                        && alert_counter < alert_number) {
+                    tmp.addElementInDataset(this, index);
+                    this.removeElementInDataset(index);
+                    alert_counter++;
+                } else if (this.getExpected().get(index) == 0
+                        && no_alert_counter < no_alert_number) {
+                    tmp.addElementInDataset(this, index);
+                    this.removeElementInDataset(index);
+                    no_alert_counter++;
+                }
+            }
+            fold_dataset.add(tmp);
+        }
+        int fold_counter = 0;
+        while (this.getLength() > 0) {
+            int i = Utils.randomInteger(0, this.getLength() - 1);
+            fold_dataset.get(fold_counter).addElementInDataset(this, i);
+            this.removeElementInDataset(i);
+            if (fold_counter == fold_dataset.size() - 1) {
+                fold_counter = 0;
+            } else {
+                fold_counter++;
+            }
+        }
+        return fold_dataset;
+    }
+
+    /**
+     * Method to increase the number of true alerts in the data,
+     * to increase the penalty for failing to detect a true alert.
+     *
+     * @param increase_ratio
+     */
+    final TrainingDataset increaseTrueAlert(
+            final int increase_ratio) {
+        int data_size = expected.size();
+        for (int i = 0; i < data_size; i++) {
+            if (expected.get(i) == 1) {
+                for (int j = 0; j < increase_ratio - 1; j++) {
+                    expected.add(expected.get(i));
+                    data.add(data.get(i));
+                }
+            }
+        }
+        return new TrainingDataset(data, expected);
+    }
 }
diff --git a/src/main/java/be/cylab/java/wowa/training/Utils.java b/src/main/java/be/cylab/java/wowa/training/Utils.java
index e8174f4be434425d68cfa1251cd77d97b3f8d493..1bcf4a6794e10731ac75884f57cc5773a990470d 100644
--- a/src/main/java/be/cylab/java/wowa/training/Utils.java
+++ b/src/main/java/be/cylab/java/wowa/training/Utils.java
@@ -4,6 +4,9 @@ import com.owlike.genson.GenericType;
 import com.owlike.genson.Genson;
 import info.debatty.java.aggregation.WOWA;
 import org.apache.commons.lang3.ArrayUtils;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.factory.Nd4j;
 
 import java.io.FileOutputStream;
 import java.io.IOException;
@@ -569,4 +572,31 @@ final class Utils {
             e.printStackTrace();
         }
     }
+
+    static DataSet createDataSet(
+            final List<List<Double>> data,
+            final List<Double> expected) {
+        if (data.size() != expected.size()) {
+            throw new IllegalArgumentException(
+                    "Data and Expected must have the same size");
+        }
+        double[][] data_array = new double[data.size()][data.get(0).size()];
+        double[][] expected_array = new double[expected.size()][2];
+        for (int i = 0; i < data.size(); i++) {
+            if (expected.get(i) == 1.0) {
+                expected_array[i][0] = 1.0;
+                expected_array[i][1] = 0.0;
+            } else {
+                expected_array[i][0] = 0.0;
+                expected_array[i][1] = 1.0;
+            }
+
+            for (int j = 0; j < data.get(i).size(); j++) {
+                data_array[i][j] = data.get(i).get(j);
+            }
+        }
+        INDArray data_ind = Nd4j.create(data_array);
+        INDArray expected_ind = Nd4j.create(expected_array);
+        return new DataSet(data_ind, expected_ind);
+    }
 }
diff --git a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java
index fd5d815f5260415d4c216c6a49134e66f358bf64..41bcd04e790770c42e29660142594b4e4f709fdf 100644
--- 
a/src/test/java/be/cylab/java/wowa/training/TrainerTest.java +++ b/src/test/java/be/cylab/java/wowa/training/TrainerTest.java @@ -22,66 +22,6 @@ class TrainerTest { void testRun() { } - @Test - void testIncreaseTrueAlert() { - List<List<Double>> data = generateData(100, 5); - int original_data_set_size = data.size(); - List<Double> expected = generateExpectedBinaryClassification(100); - int increase_ratio = 10; - int number_of_true_alert = (int)(double)Utils.sumListElements(expected); - int number_of_no_alert = original_data_set_size - number_of_true_alert; - TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio); - - //Check if the length of the dataset us correct - //increase_ration * number_of_true_alert + expected.size() - assertEquals(number_of_true_alert * increase_ratio + number_of_no_alert, ds.getLength()); - - //Check the number of true_alert in the dataset - //increase_ration * number_of_true_alert + number_of_true_alert - assertEquals(number_of_true_alert * increase_ratio, (double)Utils.sumListElements(ds.getExpected())); - - //Check if each rue alert elements are present (increase_ratio time) in he dataset - //Here, we check if each true alert element is present 10 more times than the original one - for (int i = 0; i < expected.size(); i++) { - if (ds.getExpected().get(i) == 1.0) { - int cnt = 0; - for (int j = 0; j < ds.getLength(); j++) { - if (ds.getExpected().get(i) == ds.getExpected().get(j)) { - cnt++; - } - } - assertEquals(increase_ratio, cnt); - } - } - - } - - @Test - void testPrepareFolds() { - int number_of_elements = 100; - List<List<Double>> data = generateData(number_of_elements,5); - List<Double> expected = generateExpectedBinaryClassification(number_of_elements); - int increase_ratio = 3; - int fold_number = 10; - int number_of_alert = (int)(double)Utils.sumListElements(expected); - int number_of_no_alert = number_of_elements - number_of_alert; - TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio); - List<TrainingDataset> folds = trainer.prepareFolds(ds, fold_number); - assertEquals(fold_number, folds.size()); - for (int i = 0; i < folds.size(); i++) { - assertTrue(folds.get(i).getLength() - == (number_of_alert * increase_ratio + number_of_no_alert) - / fold_number || folds.get(i).getLength() - == 1 + (number_of_alert * increase_ratio + number_of_no_alert) - / fold_number); - assertTrue(Utils.sumListElements(folds.get(i).getExpected()) - == ((number_of_alert * increase_ratio) / fold_number) || - Utils.sumListElements(folds.get(i).getExpected()) - == 1 + ((number_of_alert * increase_ratio) / fold_number) - || Utils.sumListElements(folds.get(i).getExpected()) - == 2 + ((number_of_alert * increase_ratio) / fold_number)); - } - } @Test void testFindBestSolution() { diff --git a/src/test/java/be/cylab/java/wowa/training/TrainingDatasetTest.java b/src/test/java/be/cylab/java/wowa/training/TrainingDatasetTest.java index a2eda8a8608d96c4f82146a3a909da0c961ce66f..94a5e59b57e4fc2a725de5eb412e7b424baedbf3 100644 --- a/src/test/java/be/cylab/java/wowa/training/TrainingDatasetTest.java +++ b/src/test/java/be/cylab/java/wowa/training/TrainingDatasetTest.java @@ -6,6 +6,7 @@ import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.List; +import java.util.Random; import static org.junit.jupiter.api.Assertions.*; @@ -108,4 +109,100 @@ class TrainingDatasetTest { } } } + + @Test + void testIncreaseTrueAlert() { + List<List<Double>> data = generateData(100, 5); + int original_data_set_size = data.size(); + 
List<Double> expected = generateExpectedBinaryClassification(100);
+        int increase_ratio = 10;
+        int number_of_true_alert = (int)(double)Utils.sumListElements(expected);
+        int number_of_no_alert = original_data_set_size - number_of_true_alert;
+        TrainingDataset ds = new TrainingDataset(data, expected).increaseTrueAlert(increase_ratio);
+
+        //Check if the length of the dataset is correct
+        //number_of_true_alert * increase_ratio + number_of_no_alert
+        assertEquals(number_of_true_alert * increase_ratio + number_of_no_alert, ds.getLength());
+
+        //Check the number of true alerts in the dataset
+        //(increase_ratio - 1) * number_of_true_alert copies added to the original number_of_true_alert
+        assertEquals(number_of_true_alert * increase_ratio, (double)Utils.sumListElements(ds.getExpected()));
+
+        //Check that each true alert element is present (increase_ratio times) in the dataset
+        //Here, we check that each true alert appears increase_ratio times in total: the original plus (increase_ratio - 1) copies
+        for (int i = 0; i < expected.size(); i++) {
+            if (ds.getExpected().get(i) == 1.0) {
+                int cnt = 0;
+                for (int j = 0; j < ds.getLength(); j++) {
+                    if (ds.getExpected().get(i) == ds.getExpected().get(j)) {
+                        cnt++;
+                    }
+                }
+                assertEquals(increase_ratio, cnt);
+            }
+        }
+
+    }
+
+    @Test
+    void testPrepareFolds() {
+        int number_of_elements = 100;
+        List<List<Double>> data = generateData(number_of_elements,5);
+        List<Double> expected = generateExpectedBinaryClassification(number_of_elements);
+        int increase_ratio = 3;
+        int fold_number = 10;
+        int number_of_alert = (int)(double)Utils.sumListElements(expected);
+        int number_of_no_alert = number_of_elements - number_of_alert;
+        TrainingDataset ds = new TrainingDataset(data, expected).increaseTrueAlert(increase_ratio);
+        List<TrainingDataset> folds = ds.prepareFolds(fold_number);
+        assertEquals(fold_number, folds.size());
+        for (int i = 0; i < folds.size(); i++) {
+            assertTrue(folds.get(i).getLength()
+                    == (number_of_alert * increase_ratio + number_of_no_alert)
+                    / fold_number || folds.get(i).getLength()
+                    == 1 + (number_of_alert * increase_ratio + number_of_no_alert)
+                    / fold_number);
+            assertTrue(Utils.sumListElements(folds.get(i).getExpected())
+                    == ((number_of_alert * increase_ratio) / fold_number) ||
+                    Utils.sumListElements(folds.get(i).getExpected())
+                    == 1 + ((number_of_alert * increase_ratio) / fold_number)
+                    || Utils.sumListElements(folds.get(i).getExpected())
+                    == 2 + ((number_of_alert * increase_ratio) / fold_number));
+        }
+    }
+
+    static List<List<Double>> generateData(final int size, final int weight_number) {
+        Random rnd = new Random(5489);
+        List<List<Double>> data = new ArrayList<>();
+        for (int i = 0; i < size; i++) {
+            List<Double> vector = new ArrayList<>();
+            for (int j = 0; j < weight_number; j++) {
+                vector.add(rnd.nextDouble());
+            }
+            data.add(vector);
+        }
+        return data;
+    }
+
+    static List<Double> generateExpected(final int size) {
+        Random rnd = new Random(5768);
+        List<Double> expected = new ArrayList<>();
+        for (int i = 0; i < size; i++) {
+            expected.add(rnd.nextDouble());
+        }
+        return expected;
+    }
+
+    static List<Double> generateExpectedBinaryClassification(final int size) {
+        Random rnd = new Random(5768);
+        List<Double> expected = new ArrayList<>();
+        for (int i = 0; i < size; i++) {
+            if (rnd.nextDouble() <= 0.5) {
+                expected.add(new Double(0.0));
+            } else {
+                expected.add(new Double(1.0));
+            }
+        }
+        return expected;
+    }
+}
\ No newline at end of file
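
Usage sketch (not part of the patch): how the pieces added above fit together, driving the public k-fold entry point of NeuralNetwork from the JSON resources that MainTest already uses. The class name KFoldExample is hypothetical and the hyper-parameter, fold, and ratio values are illustrative only.

package be.cylab.java.wowa.training;

import java.util.HashMap;

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;

public final class KFoldExample {

    private KFoldExample() {
    }

    public static void main(final String[] args) {
        // 20 hidden neurons, learning rate 0.01, conjugate gradient, tanh;
        // the 4-arg constructor lets percent_test_train default to 99.
        HyperParameters parameters = new HyperParameters(
                20, 0.01,
                OptimizationAlgorithm.CONJUGATE_GRADIENT,
                Activation.TANH);
        NeuralNetwork nn = new NeuralNetwork(parameters);

        // 10-fold cross-validation; true alerts are duplicated 10x
        // in each training fold (increase_ratio).
        HashMap<MultiLayerNetwork, Double> scores = nn.runKFold(
                "ressources/webshell_data.json",
                "ressources/webshell_expected.json",
                10, 10);

        // The map holds one trained model and its test AUC per fold.
        double sum = 0.0;
        for (Double auc : scores.values()) {
            sum = sum + auc;
        }
        System.out.println("Average AUC : " + sum / scores.size());
    }
}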