diff --git a/src/main/java/be/cylab/java/wowa/training/MainTest.java b/src/main/java/be/cylab/java/wowa/training/MainTest.java index addb503e522f2715acca4d65e2cd47eea1c4bc8a..10cb9fc2b29a9bd5906cfe0c6dab0146edc2bbeb 100644 --- a/src/main/java/be/cylab/java/wowa/training/MainTest.java +++ b/src/main/java/be/cylab/java/wowa/training/MainTest.java @@ -59,9 +59,10 @@ public final class MainTest { NeuralNetwork nn = new NeuralNetwork(nn_parameters); TrainingDataset dataset = new TrainingDataset(data, expected); - List<TrainingDataset> folds = dataset.prepareFolds(3); - + List<TrainingDataset> folds = dataset.prepareFolds(10); + System.out.println("Neural Network learning"); HashMap<MultiLayerNetwork, Double> map_nn = nn.runKFold(folds, 10); + System.out.println("Wowa training"); HashMap<AbstractSolution, Double> map_wt = trainer.runKFold(folds, 10); @@ -72,12 +73,12 @@ public final class MainTest { } System.out.println("Average AUC for Neural Network learning : " - + nn_score / 3); + + nn_score / 10); for (Double d : map_wt.values()) { wt_score = wt_score + d; } - System.out.println("Average AUC for WOWA learning : " + wt_score / 3); + System.out.println("Average AUC for WOWA learning : " + wt_score / 10); } }