diff --git a/src/main/java/be/cylab/java/wowa/training/Example.java b/src/main/java/be/cylab/java/wowa/training/Example.java index 46ee6ca648949f67859870f97439302e8a7e9ee6..d1055c0ba01da1379640af927d2298bbb5edc1e3 100644 --- a/src/main/java/be/cylab/java/wowa/training/Example.java +++ b/src/main/java/be/cylab/java/wowa/training/Example.java @@ -1,7 +1,8 @@ package be.cylab.java.wowa.training; -import java.util.List; +import java.util.HashMap; +import java.util.Map; import java.util.logging.Level; import java.util.logging.Logger; import java.util.logging.SimpleFormatter; @@ -73,14 +74,14 @@ public final class Example { + (end_time - start_time) / 1000 + " seconds"); System.out.println("Run Cross validation"); - List<Double> solutions = trainer.runKFold( + HashMap<AbstractSolution, Double> solutions = trainer.runKFold( data_file, expected_file, + 10, 10); - for (int i = 0; i < solutions.size(); i++) { - System.out.println("AUC Value for fold " + i + ": " - + solutions.get(i)); + for (Map.Entry val : solutions.entrySet()) { + System.out.println(val); } } diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java index a15548c2c3e5ed8907676b17e32f9f17eea30ced..91fc9d39d59f66e9b20460c877d62656fdf52883 100644 --- a/src/main/java/be/cylab/java/wowa/training/Trainer.java +++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java @@ -3,6 +3,7 @@ package be.cylab.java.wowa.training; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.logging.Level; @@ -96,13 +97,14 @@ public class Trainer { * @param fold_number * @return */ - public final List<Double> runKFold( + public final HashMap<AbstractSolution, Double> runKFold( final List<List<Double>> data, final List<Double> expected, - final int fold_number) { + final int fold_number, + final int increase_ration_alert) { TrainingDataset dataset = new TrainingDataset(data, expected); List<TrainingDataset> folds = prepareFolds(dataset, fold_number); - List<Double> auc = new ArrayList<>(); + HashMap<AbstractSolution, Double> map = new HashMap<>(); for (int i = 0; i < fold_number; i++) { TrainingDataset testing = folds.get(i); TrainingDataset learning = new TrainingDataset(); @@ -114,35 +116,38 @@ public class Trainer { TrainingDataset dataset_increased = increaseTrueAlert( learning.getData(), learning.getExpected(), - fold_number); + increase_ration_alert); AbstractSolution sol = run( dataset_increased.getData(), dataset_increased.getExpected()); Double score = sol.computeAUC( testing.getData(), testing.getExpected()); - auc.add(score); + + map.put(sol, score); } - return auc; + return map; } /** * Method to perform a cross validation with filename as argument. + * * @param data_file_name * @param expected_file_name * @param fold_number * @return */ - public final List<Double> runKFold( + public final HashMap<AbstractSolution, Double> runKFold( final String data_file_name, final String expected_file_name, - final int fold_number) { + final int fold_number, + final int increase_ratio_alert) { List<List<Double>> data = Utils.convertJsonToDataForTrainer(data_file_name); List<Double> expected = Utils.convertJsonToExpectedForTrainer(expected_file_name); - return runKFold(data, expected, fold_number); + return runKFold(data, expected, fold_number, increase_ratio_alert); } /** @@ -171,6 +176,7 @@ public class Trainer { /** * Method to separate randomly the base dataset in X folds dataset. + * * @param dataset * @param fold_number * @return