From 79cfec617a12242b1bff771f156d715843415313 Mon Sep 17 00:00:00 2001 From: Alex <croix.alexandre@gmail.com> Date: Mon, 9 Sep 2019 15:47:09 +0200 Subject: [PATCH] Add some logs --- src/main/java/be/cylab/java/wowa/training/MainTest.java | 6 +++--- .../java/be/cylab/java/wowa/training/NeuralNetwork.java | 2 +- src/main/java/be/cylab/java/wowa/training/Trainer.java | 7 ++++++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/main/java/be/cylab/java/wowa/training/MainTest.java b/src/main/java/be/cylab/java/wowa/training/MainTest.java index 16f3348..addb503 100644 --- a/src/main/java/be/cylab/java/wowa/training/MainTest.java +++ b/src/main/java/be/cylab/java/wowa/training/MainTest.java @@ -59,7 +59,7 @@ public final class MainTest { NeuralNetwork nn = new NeuralNetwork(nn_parameters); TrainingDataset dataset = new TrainingDataset(data, expected); - List<TrainingDataset> folds = dataset.prepareFolds(10); + List<TrainingDataset> folds = dataset.prepareFolds(3); HashMap<MultiLayerNetwork, Double> map_nn = nn.runKFold(folds, 10); HashMap<AbstractSolution, Double> map_wt = trainer.runKFold(folds, 10); @@ -72,12 +72,12 @@ public final class MainTest { } System.out.println("Average AUC for Neural Network learning : " - + nn_score / 10); + + nn_score / 3); for (Double d : map_wt.values()) { wt_score = wt_score + d; } - System.out.println("Average AUC for WOWA learning : " + wt_score / 10); + System.out.println("Average AUC for WOWA learning : " + wt_score / 3); } } diff --git a/src/main/java/be/cylab/java/wowa/training/NeuralNetwork.java b/src/main/java/be/cylab/java/wowa/training/NeuralNetwork.java index 6c3eb66..fc3d513 100644 --- a/src/main/java/be/cylab/java/wowa/training/NeuralNetwork.java +++ b/src/main/java/be/cylab/java/wowa/training/NeuralNetwork.java @@ -265,11 +265,11 @@ public final class NeuralNetwork { } TrainingDataset dataset_increased = learning_fold.increaseTrueAlert( increase_ratio); + System.out.println("Fold number : " + (i + 1)); MultiLayerNetwork nn = run( dataset_increased.getData(), dataset_increased.getExpected()); Double score = modelEvaluation(testingfold, nn); - map.put(nn, score); } return map; diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java index 0cece52..fda997a 100644 --- a/src/main/java/be/cylab/java/wowa/training/Trainer.java +++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java @@ -1,7 +1,11 @@ package be.cylab.java.wowa.training; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Random; import java.util.logging.Level; /** @@ -178,6 +182,7 @@ public class Trainer { Double score = sol.computeAUC( testing.getData(), testing.getExpected()); + System.out.println("Fold number : " + (i + 1) + "AUC : " + score); map.put(sol, score); } -- GitLab