Skip to content
Snippets Groups Projects
Commit 6ea9c0b0 authored by a.croix's avatar a.croix
Browse files

Add new fils for new dataset

parent 6d1239ab
No related branches found
No related tags found
1 merge request!4Neural network
Pipeline #2433 failed
This diff is collapsed.
This diff is collapsed.
......@@ -28,8 +28,8 @@ public final class MainTest {
= TrainerParameters.SELECTION_METHOD_RWS;
int generation_population_method
= TrainerParameters.POPULATION_INITIALIZATION_RANDOM;
String data_file = "ressources/webshell_data.json";
String expected_file = "ressources/webshell_expected.json";
String data_file = "./ressources/webshell_data_new_version.json";
String expected_file = "ressources/webshell_expected_new_version.json";
List<List<Double>> data
= Utils.convertJsonToDataForTrainer(data_file);
List<Double> expected
......
......@@ -4,9 +4,12 @@ import com.owlike.genson.GenericType;
import com.owlike.genson.Genson;
import info.debatty.java.aggregation.WOWA;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.datasets.iterator.DoublesDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.io.FileOutputStream;
import java.io.IOException;
......@@ -599,4 +602,36 @@ final class Utils {
INDArray expected_ind = Nd4j.create(expected_array);
return new DataSet(data_ind, expected_ind);
}
static DataSetIterator createDataSetIterator(
final List<List<Double>> data,
final List<Double> expected,
final int batch_size
) {
List<Pair<double[], double[]>> ds = new ArrayList<>();
if (data.size() != expected.size()) {
throw new IllegalArgumentException(
"Data and Expected must have the same size");
}
for (int i = 0; i < data.size(); i++) {
if (expected.get(i) == 1.0) {
double[] d = convertListDoubleToArrayDouble(data.get(i));
double[] e = {1.0, 0.0};
Pair<double[], double[]> p
= new Pair<>(d, e);
ds.add(p);
} else {
double[] d = convertListDoubleToArrayDouble(data.get(i));
double[] e = {0.0, 1.0};
Pair<double[], double[]> p
= new Pair<>(d, e);
ds.add(p);
}
}
DataSetIterator dsi = new DoublesDataSetIterator(ds, batch_size);
return dsi;
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment