Skip to content
Snippets Groups Projects
Commit b2d62b05 authored by a.croix's avatar a.croix
Browse files

Update tests to match with new project structure. Solve issue #8

parent 5a08bc20
No related branches found
No related tags found
1 merge request!4Neural network
Pipeline #2264 failed
...@@ -169,6 +169,11 @@ public final class NeuralNetwork { ...@@ -169,6 +169,11 @@ public final class NeuralNetwork {
return roc.calculateAUC(); return roc.calculateAUC();
} }
/**
* @param testing
* @param network
* @return
*/
public Double modelEvaluation( public Double modelEvaluation(
final TrainingDataset testing, final TrainingDataset testing,
final MultiLayerNetwork network) { final MultiLayerNetwork network) {
......
...@@ -22,66 +22,6 @@ class TrainerTest { ...@@ -22,66 +22,6 @@ class TrainerTest {
void testRun() { void testRun() {
} }
@Test
void testIncreaseTrueAlert() {
List<List<Double>> data = generateData(100, 5);
int original_data_set_size = data.size();
List<Double> expected = generateExpectedBinaryClassification(100);
int increase_ratio = 10;
int number_of_true_alert = (int)(double)Utils.sumListElements(expected);
int number_of_no_alert = original_data_set_size - number_of_true_alert;
TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio);
//Check if the length of the dataset us correct
//increase_ration * number_of_true_alert + expected.size()
assertEquals(number_of_true_alert * increase_ratio + number_of_no_alert, ds.getLength());
//Check the number of true_alert in the dataset
//increase_ration * number_of_true_alert + number_of_true_alert
assertEquals(number_of_true_alert * increase_ratio, (double)Utils.sumListElements(ds.getExpected()));
//Check if each rue alert elements are present (increase_ratio time) in he dataset
//Here, we check if each true alert element is present 10 more times than the original one
for (int i = 0; i < expected.size(); i++) {
if (ds.getExpected().get(i) == 1.0) {
int cnt = 0;
for (int j = 0; j < ds.getLength(); j++) {
if (ds.getExpected().get(i) == ds.getExpected().get(j)) {
cnt++;
}
}
assertEquals(increase_ratio, cnt);
}
}
}
@Test
void testPrepareFolds() {
int number_of_elements = 100;
List<List<Double>> data = generateData(number_of_elements,5);
List<Double> expected = generateExpectedBinaryClassification(number_of_elements);
int increase_ratio = 3;
int fold_number = 10;
int number_of_alert = (int)(double)Utils.sumListElements(expected);
int number_of_no_alert = number_of_elements - number_of_alert;
TrainingDataset ds = trainer.increaseTrueAlert(data, expected, increase_ratio);
List<TrainingDataset> folds = trainer.prepareFolds(ds, fold_number);
assertEquals(fold_number, folds.size());
for (int i = 0; i < folds.size(); i++) {
assertTrue(folds.get(i).getLength()
== (number_of_alert * increase_ratio + number_of_no_alert)
/ fold_number || folds.get(i).getLength()
== 1 + (number_of_alert * increase_ratio + number_of_no_alert)
/ fold_number);
assertTrue(Utils.sumListElements(folds.get(i).getExpected())
== ((number_of_alert * increase_ratio) / fold_number) ||
Utils.sumListElements(folds.get(i).getExpected())
== 1 + ((number_of_alert * increase_ratio) / fold_number)
|| Utils.sumListElements(folds.get(i).getExpected())
== 2 + ((number_of_alert * increase_ratio) / fold_number));
}
}
@Test @Test
void testFindBestSolution() { void testFindBestSolution() {
......
...@@ -6,6 +6,7 @@ import org.junit.jupiter.api.Test; ...@@ -6,6 +6,7 @@ import org.junit.jupiter.api.Test;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Random;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
...@@ -108,4 +109,100 @@ class TrainingDatasetTest { ...@@ -108,4 +109,100 @@ class TrainingDatasetTest {
} }
} }
} }
@Test
void testIncreaseTrueAlert() {
List<List<Double>> data = generateData(100, 5);
int original_data_set_size = data.size();
List<Double> expected = generateExpectedBinaryClassification(100);
int increase_ratio = 10;
int number_of_true_alert = (int)(double)Utils.sumListElements(expected);
int number_of_no_alert = original_data_set_size - number_of_true_alert;
TrainingDataset ds = new TrainingDataset(data, expected).increaseTrueAlert(increase_ratio);
//Check if the length of the dataset us correct
//increase_ration * number_of_true_alert + expected.size()
assertEquals(number_of_true_alert * increase_ratio + number_of_no_alert, ds.getLength());
//Check the number of true_alert in the dataset
//increase_ration * number_of_true_alert + number_of_true_alert
assertEquals(number_of_true_alert * increase_ratio, (double)Utils.sumListElements(ds.getExpected()));
//Check if each rue alert elements are present (increase_ratio time) in he dataset
//Here, we check if each true alert element is present 10 more times than the original one
for (int i = 0; i < expected.size(); i++) {
if (ds.getExpected().get(i) == 1.0) {
int cnt = 0;
for (int j = 0; j < ds.getLength(); j++) {
if (ds.getExpected().get(i) == ds.getExpected().get(j)) {
cnt++;
}
}
assertEquals(increase_ratio, cnt);
}
}
}
@Test
void testPrepareFolds() {
int number_of_elements = 100;
List<List<Double>> data = generateData(number_of_elements,5);
List<Double> expected = generateExpectedBinaryClassification(number_of_elements);
int increase_ratio = 3;
int fold_number = 10;
int number_of_alert = (int)(double)Utils.sumListElements(expected);
int number_of_no_alert = number_of_elements - number_of_alert;
TrainingDataset ds = new TrainingDataset(data, expected).increaseTrueAlert(increase_ratio);
List<TrainingDataset> folds = ds.prepareFolds(fold_number);
assertEquals(fold_number, folds.size());
for (int i = 0; i < folds.size(); i++) {
assertTrue(folds.get(i).getLength()
== (number_of_alert * increase_ratio + number_of_no_alert)
/ fold_number || folds.get(i).getLength()
== 1 + (number_of_alert * increase_ratio + number_of_no_alert)
/ fold_number);
assertTrue(Utils.sumListElements(folds.get(i).getExpected())
== ((number_of_alert * increase_ratio) / fold_number) ||
Utils.sumListElements(folds.get(i).getExpected())
== 1 + ((number_of_alert * increase_ratio) / fold_number)
|| Utils.sumListElements(folds.get(i).getExpected())
== 2 + ((number_of_alert * increase_ratio) / fold_number));
}
}
static List<List<Double>> generateData(final int size, final int weight_number) {
Random rnd = new Random(5489);
List<List<Double>> data = new ArrayList<>();
for (int i = 0; i < size; i++) {
List<Double> vector = new ArrayList<>();
for (int j = 0; j < weight_number; j++) {
vector.add(rnd.nextDouble());
}
data.add(vector);
}
return data;
}
static List<Double> generateExpected(final int size) {
Random rnd = new Random(5768);
List<Double> expected = new ArrayList<>();
for (int i = 0; i < size; i++) {
expected.add(rnd.nextDouble());
}
return expected;
}
static List<Double> generateExpectedBinaryClassification(final int size) {
Random rnd = new Random(5768);
List<Double> expected = new ArrayList<>();
for (int i = 0; i < size; i++) {
if (rnd.nextDouble() <= 0.5) {
expected.add(new Double(0.0));
} else {
expected.add(new Double(1.0));
}
}
return expected;
}
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment