Skip to content
Snippets Groups Projects
Commit 4370c214 authored by a.croix's avatar a.croix
Browse files

Merge branch 'k-fold' into 'master'

K fold

See merge request !3
parents 5a1c2d69 f49b04cd
No related branches found
No related tags found
1 merge request!3K fold
Pipeline #1766 passed
......@@ -20,7 +20,7 @@ test:mvn:jdk8:
image: maven:3.5.3-jdk-8
script:
- mvn clean verify -Dgpg.skip -Dmaven.repo.local=.m2
- mvn exec:java -Dexec.mainClass="be.cylab.java.wowa.training.Example" -Dexec.args="100 60 10 110 RWS RANDOM" -Dmaven.repo.local=.m2
test:mvn:jdk9:
......@@ -28,7 +28,6 @@ test:mvn:jdk9:
image: maven:3.5.3-jdk-9
script:
- mvn clean verify -Dgpg.skip -Dmaven.repo.local=.m2
- mvn exec:java -Dexec.mainClass="be.cylab.java.wowa.training.Example" -Dexec.args="100 60 10 110 RWS RANDOM" -Dmaven.repo.local=.m2
test:mvn:jdk10:
......@@ -36,7 +35,6 @@ test:mvn:jdk10:
image: maven:3.5.3-jdk-10
script:
- mvn clean verify -Dgpg.skip -Dmaven.repo.local=.m2
- mvn exec:java -Dexec.mainClass="be.cylab.java.wowa.training.Example" -Dexec.args="100 60 10 110 RWS RANDOM" -Dmaven.repo.local=.m2
##mvn:jdk11:
## image: maven:3.5.3-jdk-11
......
......@@ -7,9 +7,9 @@ import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Random;
import java.util.Date;
/**
......@@ -126,6 +126,7 @@ public abstract class AbstractSolution
}
/**
* Method to compute AUC with List as argument.
* @param data
* @param expected
* @return
......@@ -147,6 +148,22 @@ public abstract class AbstractSolution
return roc.computeAUC();
}
/**
 * Computes the AUC of this solution from two files instead of two lists.
 *
 * Both files are converted with
 * {@code Utils.convertJsonToDataForTrainer} and
 * {@code Utils.convertJsonToExpectedForTrainer}; the resulting lists are
 * then handed to the list-based {@code computeAUC} overload.
 *
 * @param filename_data name of the file passed to
 *                      {@code Utils.convertJsonToDataForTrainer}
 * @param filename_expected name of the file passed to
 *                          {@code Utils.convertJsonToExpectedForTrainer}
 * @return the value returned by the list-based {@code computeAUC} overload
 */
public final double computeAUC(
        final String filename_data,
        final String filename_expected) {
    return computeAUC(
            Utils.convertJsonToDataForTrainer(filename_data),
            Utils.convertJsonToExpectedForTrainer(filename_expected));
}
/**
* @param solution
* @return int
......
package be.cylab.java.wowa.training;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.logging.SimpleFormatter;
......@@ -46,7 +48,8 @@ public final class Example {
"Initialization must be RANDOM or QUASI_RANDOM");
}
String data_file = "./ressources/webshell_data.json";
String expected_file = "./ressources/webshell_expected.json";
Logger logger = Logger.getLogger(Trainer.class.getName());
logger.setLevel(Level.INFO);
StreamHandler handler = new StreamHandler(System.out,
......@@ -70,7 +73,16 @@ public final class Example {
logger.log(Level.INFO, "Execution time : "
+ (end_time - start_time) / 1000 + " seconds");
System.out.println("Run Cross validation");
HashMap<AbstractSolution, Double> solutions = trainer.runKFold(
data_file,
expected_file,
10,
10);
for (Map.Entry val : solutions.entrySet()) {
System.out.println(val);
}
}
......
......@@ -3,6 +3,7 @@ package be.cylab.java.wowa.training;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.logging.Level;
......@@ -30,6 +31,7 @@ public class Trainer {
/**
* Run method with file as arguments.
*
* @param data_file_name
* @param expected_file_name
* @return
......@@ -87,6 +89,134 @@ public class Trainer {
return best_solution;
}
/**
 * Performs a k-fold cross validation.
 *
 * The dataset is split into {@code fold_number} folds. Each fold is used
 * once as the testing set while all the other folds are merged into the
 * learning set. The learning set goes through {@code increaseTrueAlert}
 * before being given to {@code run}, and the solution found is scored with
 * {@code computeAUC} on the testing fold.
 *
 * The lists received as arguments are never modified: the method works on
 * shallow copies of them.
 *
 * @param data data elements, one inner list per element
 * @param expected expected value of each element, same size as {@code data}
 * @param fold_number number of folds, must be strictly positive
 * @param increase_ration_alert number of extra copies added to the learning
 *                              set for every element whose expected value
 *                              is 1
 * @return a map associating the solution trained for each fold with the AUC
 *         it obtained on the corresponding testing fold
 * @throws IllegalArgumentException if {@code fold_number} is lower than 1
 *                                  or if the two lists differ in size
 */
public final HashMap<AbstractSolution, Double> runKFold(
        final List<List<Double>> data,
        final List<Double> expected,
        final int fold_number,
        final int increase_ration_alert) {
    if (fold_number < 1) {
        throw new IllegalArgumentException(
                "fold_number must be strictly positive");
    }

    // Work on copies so that preparing the folds can never alter the
    // lists owned by the caller.
    TrainingDataset dataset = new TrainingDataset(
            new ArrayList<>(data),
            new ArrayList<>(expected));
    List<TrainingDataset> folds = prepareFolds(dataset, fold_number);

    HashMap<AbstractSolution, Double> map = new HashMap<>();
    for (int i = 0; i < fold_number; i++) {
        TrainingDataset testing = folds.get(i);

        // Learning set = every fold except the testing one.
        TrainingDataset learning = new TrainingDataset();
        for (int j = 0; j < fold_number; j++) {
            if (j != i) {
                learning.addFoldInDataset(folds, j);
            }
        }

        TrainingDataset dataset_increased = increaseTrueAlert(
                learning.getData(),
                learning.getExpected(),
                increase_ration_alert);
        AbstractSolution sol = run(
                dataset_increased.getData(),
                dataset_increased.getExpected());
        double score = sol.computeAUC(
                testing.getData(),
                testing.getExpected());
        map.put(sol, score);
    }
    return map;
}
/**
 * Performs a k-fold cross validation on a dataset stored in files.
 *
 * The two files are converted with
 * {@code Utils.convertJsonToDataForTrainer} and
 * {@code Utils.convertJsonToExpectedForTrainer}, and the work is delegated
 * to the list-based {@code runKFold} overload.
 *
 * @param data_file_name name of the file passed to
 *                       {@code Utils.convertJsonToDataForTrainer}
 * @param expected_file_name name of the file passed to
 *                           {@code Utils.convertJsonToExpectedForTrainer}
 * @param fold_number number of folds
 * @param increase_ratio_alert forwarded unchanged as the last argument of
 *                             the list-based overload
 * @return the map returned by the list-based {@code runKFold} overload
 */
public final HashMap<AbstractSolution, Double> runKFold(
        final String data_file_name,
        final String expected_file_name,
        final int fold_number,
        final int increase_ratio_alert) {
    return runKFold(
            Utils.convertJsonToDataForTrainer(data_file_name),
            Utils.convertJsonToExpectedForTrainer(expected_file_name),
            fold_number,
            increase_ratio_alert);
}
/**
 * Builds a dataset in which every true alert is duplicated.
 *
 * For each element whose expected value is 1, {@code increase_ratio}
 * extra copies of the element (and of its expected value) are appended
 * after the original elements, which increases the penalty for missing a
 * true alert. The lists received as arguments are left untouched: the
 * result is built in new lists.
 *
 * @param data data elements, one inner list per element
 * @param expected expected value of each element
 * @param increase_ratio number of extra copies added per true alert; a
 *                       value lower than 1 adds nothing
 * @return a new dataset holding the original elements followed by the
 *         duplicated true alerts
 */
private TrainingDataset increaseTrueAlert(
        final List<List<Double>> data,
        final List<Double> expected,
        final int increase_ratio) {
    // Copy first: appending to the arguments would silently modify the
    // dataset they were taken from.
    List<List<Double>> increased_data = new ArrayList<>(data);
    List<Double> increased_expected = new ArrayList<>(expected);

    int original_size = expected.size();
    for (int i = 0; i < original_size; i++) {
        if (expected.get(i) == 1) {
            for (int j = 0; j < increase_ratio; j++) {
                increased_expected.add(expected.get(i));
                increased_data.add(data.get(i));
            }
        }
    }
    return new TrainingDataset(increased_data, increased_expected);
}
/**
 * Randomly splits the base dataset into {@code fold_number} folds.
 *
 * Every fold targets {@code size / fold_number} elements, of which
 * {@code floor(sum(expected) / fold_number)} have an expected value of 1
 * and the rest an expected value of 0. Elements are drawn without
 * replacement. Because both quotas are rounded down independently, the
 * number of elements with expected value 0 requested over all folds can
 * exceed what the dataset holds; each quota is therefore capped by the
 * number of matching elements still available, so the last folds may be
 * slightly smaller instead of the draw never ending. Elements left over
 * after the last fold are not used.
 *
 * The dataset received as argument is not modified.
 *
 * @param dataset dataset to split
 * @param fold_number number of folds, must be strictly positive
 * @return the list of folds, in creation order
 * @throws IllegalArgumentException if {@code fold_number} is lower than 1
 */
private List<TrainingDataset> prepareFolds(
        final TrainingDataset dataset,
        final int fold_number) {
    if (fold_number < 1) {
        throw new IllegalArgumentException(
                "fold_number must be strictly positive");
    }
    List<Double> expected = dataset.getExpected();

    int element_number_in_fold = expected.size() / fold_number;
    int alert_number
            = (int) Math.floor(Utils.sumListElements(expected)
            / fold_number);
    int no_alert_number = element_number_in_fold - alert_number;

    // Pool of indexes not yet assigned to a fold, with the count of each
    // class still present in the pool.
    List<Integer> remaining = new ArrayList<>(expected.size());
    int remaining_alert = 0;
    int remaining_no_alert = 0;
    for (int i = 0; i < expected.size(); i++) {
        remaining.add(i);
        double label = expected.get(i);
        if (label == 1) {
            remaining_alert++;
        } else if (label == 0) {
            remaining_no_alert++;
        }
    }

    List<TrainingDataset> fold_dataset = new ArrayList<>();
    for (int i = 0; i < fold_number; i++) {
        TrainingDataset tmp = new TrainingDataset();
        // Never ask for more elements of a class than the pool holds,
        // otherwise the rejection loop below could not terminate.
        int alert_quota = Math.min(alert_number, remaining_alert);
        int no_alert_quota = Math.min(no_alert_number, remaining_no_alert);
        int alert_counter = 0;
        int no_alert_counter = 0;

        while (alert_counter < alert_quota
                || no_alert_counter < no_alert_quota) {
            // NOTE(review): assumes Utils.randomInteger treats both
            // bounds as inclusive, as the "size - 1" upper bound suggests.
            int position = Utils.randomInteger(0, remaining.size() - 1);
            int index = remaining.get(position);
            double label = expected.get(index);
            if (label == 1 && alert_counter < alert_quota) {
                tmp.addElementInDataset(dataset, index);
                remaining.remove(position);
                alert_counter++;
                remaining_alert--;
            } else if (label == 0 && no_alert_counter < no_alert_quota) {
                tmp.addElementInDataset(dataset, index);
                remaining.remove(position);
                no_alert_counter++;
                remaining_no_alert--;
            }
        }
        fold_dataset.add(tmp);
    }
    return fold_dataset;
}
/**
* Find the best element in the population based on its fitness score.
*
......
package be.cylab.java.wowa.training;
import java.util.ArrayList;
import java.util.List;
/**
 * Class to manage the training dataset.
 *
 * A dataset is a list of data elements (each one a list of values) paired
 * with a list holding the expected value of every element. The dataset
 * keeps its own lists: the ones given to the constructor are copied
 * (shallow copy, the inner lists are shared), so adding or removing
 * elements never alters the caller's lists.
 */
class TrainingDataset {

    private final List<List<Double>> data;
    private final List<Double> expected;

    /**
     * Default constructor, creates an empty dataset.
     */
    TrainingDataset() {
        this.data = new ArrayList<>();
        this.expected = new ArrayList<>();
    }

    /**
     * Constructor takes List as arguments.
     *
     * @param data data elements
     * @param expected expected value of each element
     * @throws IllegalArgumentException if the two lists differ in size
     */
    TrainingDataset(
            final List<List<Double>> data,
            final List<Double> expected) {
        if (data.size() != expected.size()) {
            throw new IllegalArgumentException(
                    "Data and expected list must have the same size");
        }
        // Copy so that removeElementInDataset cannot destroy the lists
        // owned by the caller.
        this.data = new ArrayList<>(data);
        this.expected = new ArrayList<>(expected);
    }

    /**
     * Getter for data list.
     *
     * @return the internal (live) list of data elements
     */
    List<List<Double>> getData() {
        return data;
    }

    /**
     * Getter for expected list.
     *
     * @return the internal (live) list of expected values
     */
    public List<Double> getExpected() {
        return expected;
    }

    /**
     * Length getter.
     *
     * The length is read from the data list itself, so it stays correct
     * even when the lists are modified through the getters.
     *
     * @return the number of elements in the dataset
     */
    int getLength() {
        return this.data.size();
    }

    /**
     * Copy an element from a dataset to another one.
     *
     * @param dataset dataset the element is read from
     * @param index position of the element in {@code dataset}
     * @return this dataset, to allow chaining
     */
    TrainingDataset addElementInDataset(
            final TrainingDataset dataset,
            final int index) {
        this.data.add(dataset.getData().get(index));
        this.expected.add(dataset.getExpected().get(index));
        return this;
    }

    /**
     * Remove an element from a dataset.
     *
     * @param index position of the element to remove
     */
    void removeElementInDataset(final int index) {
        // index is a primitive int, so these are positional removals.
        this.data.remove(index);
        this.expected.remove(index);
    }

    /**
     * Method to add a fold to a dataset (learning dataset creation).
     *
     * @param data_to_add list of folds
     * @param index position in {@code data_to_add} of the fold to append
     * @return this dataset, to allow chaining
     */
    TrainingDataset addFoldInDataset(
            final List<TrainingDataset> data_to_add,
            final int index) {
        TrainingDataset fold = data_to_add.get(index);
        this.data.addAll(fold.data);
        this.expected.addAll(fold.expected);
        return this;
    }
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment