Skip to content
Snippets Groups Projects
Commit 79cfec61 authored by a.croix's avatar a.croix
Browse files

Add some logs

parent f2b8d5b7
No related branches found
No related tags found
1 merge request!4Neural network
Pipeline #2293 failed
......@@ -59,7 +59,7 @@ public final class MainTest {
NeuralNetwork nn = new NeuralNetwork(nn_parameters);
TrainingDataset dataset = new TrainingDataset(data, expected);
List<TrainingDataset> folds = dataset.prepareFolds(10);
List<TrainingDataset> folds = dataset.prepareFolds(3);
HashMap<MultiLayerNetwork, Double> map_nn = nn.runKFold(folds, 10);
HashMap<AbstractSolution, Double> map_wt = trainer.runKFold(folds, 10);
......@@ -72,12 +72,12 @@ public final class MainTest {
}
System.out.println("Average AUC for Neural Network learning : "
+ nn_score / 10);
+ nn_score / 3);
for (Double d : map_wt.values()) {
wt_score = wt_score + d;
}
System.out.println("Average AUC for WOWA learning : " + wt_score / 10);
System.out.println("Average AUC for WOWA learning : " + wt_score / 3);
}
}
......@@ -265,11 +265,11 @@ public final class NeuralNetwork {
}
TrainingDataset dataset_increased = learning_fold.increaseTrueAlert(
increase_ratio);
System.out.println("Fold number : " + (i + 1));
MultiLayerNetwork nn = run(
dataset_increased.getData(),
dataset_increased.getExpected());
Double score = modelEvaluation(testingfold, nn);
map.put(nn, score);
}
return map;
......
package be.cylab.java.wowa.training;
import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Random;
import java.util.logging.Level;
/**
......@@ -178,6 +182,7 @@ public class Trainer {
Double score = sol.computeAUC(
testing.getData(),
testing.getExpected());
System.out.println("Fold number : " + (i + 1) + "AUC : " + score);
map.put(sol, score);
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment