diff --git a/src/main/java/be/cylab/java/wowa/training/Example.java b/src/main/java/be/cylab/java/wowa/training/Example.java index e93c03cab8f0e295578b1db4de5cd1dabc23763a..c817f0f4372db80b4004658840c39ef5b9077816 100644 --- a/src/main/java/be/cylab/java/wowa/training/Example.java +++ b/src/main/java/be/cylab/java/wowa/training/Example.java @@ -52,12 +52,24 @@ public final class Example { throw new IllegalArgumentException( "Initialization must be RANDOM or QUASI_RANDOM"); } - int fold_number = Integer.parseInt(args[6]); - int increase_ratio = Integer.parseInt(args[7]); + String solution_type = args[6]; + AbstractSolution sol_type = null; + if (solution_type.matches("DISTANCE")) { + sol_type = new SolutionDistance(5); + } else if (solution_type.matches("AUC")) { + sol_type = new SolutionAUC(5); + } else { + throw new IllegalArgumentException( + "Solution type must be Distance or AUC"); + } + int fold_number = Integer.parseInt(args[7]); + int increase_ratio = Integer.parseInt(args[8]); - String data_file = "./ressources/webshell_data.json"; + + + String data_file = "./ressources/webshell_data_new_version.json"; String expected_file - = "./ressources/webshell_expected.json"; + = "./ressources/webshell_expected_new_version.json"; Logger logger = Logger.getLogger(Trainer.class.getName()); logger.setLevel(Level.INFO); StreamHandler handler = new StreamHandler(System.out, @@ -70,7 +82,7 @@ public final class Example { max_generation_number, selection_method, generation_population_method); - Trainer trainer = new Trainer(parameters, new SolutionDistance(5)); + Trainer trainer = new Trainer(parameters, sol_type); /* AbstractSolution solution = trainer.run( data_file, @@ -79,8 +91,9 @@ public final class Example { System.out.println("Run Cross validation"); */ + long start_time = System.currentTimeMillis(); - HashMap<AbstractSolution, double[]> solutions = trainer.runKFold( + HashMap<AbstractSolution, double[]> solutions = trainer.runKFoldCSV( data_file, expected_file, fold_number, @@ -110,18 +123,17 @@ public final class Example { } catch (IOException e) { e.printStackTrace(); } - - /* - AbstractSolution sol = trainer.run(data_file, expected_file); - List<List<Double>> data = Utils.convertJsonToDataForTrainer( + + AbstractSolution sol = trainer.runCSV(data_file, expected_file); + List<List<Double>> data = Utils.convertCSVToDataForTrainer( data_file); - List<Double> expected = Utils.convertJsonToExpectedForTrainer( + List<Double> expected = Utils.convertCSVToExpectedForTrainer( expected_file); WOWA wowa = new WOWA( Utils.convertListDoubleToArrayDouble(sol.getWeightsW()), Utils.convertListDoubleToArrayDouble(sol.getWeightsP())); - FileWriter csv = new FileWriter("CSV.csv"); + FileWriter csv = new FileWriter("CSV2.csv"); for (int i = 0; i < expected.size(); i++) { double[] data_array = Utils.convertListDoubleToArrayDouble( data.get(i)); @@ -134,8 +146,8 @@ public final class Example { } csv.flush(); csv.close(); -*/ +*/ } } diff --git a/src/main/java/be/cylab/java/wowa/training/Trainer.java b/src/main/java/be/cylab/java/wowa/training/Trainer.java index 3a882f7a5c71382a1ebdac36f4fbc47a0e4cc07e..dde961695c6bae24c19e4921546528e4bcdd5921 100644 --- a/src/main/java/be/cylab/java/wowa/training/Trainer.java +++ b/src/main/java/be/cylab/java/wowa/training/Trainer.java @@ -1,11 +1,13 @@ package be.cylab.java.wowa.training; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Random; +import java.io.BufferedReader; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.*; import java.util.logging.Level; /** @@ -47,6 +49,22 @@ public class Trainer { return this.run(data, expected); } + /** + * Run with CSV files as arguments. + * @param data_file + * @param expected_file + * @return + */ + public final AbstractSolution runCSV( + final String data_file, + final String expected_file) { + List<List<Double>> data = Utils.convertCSVToDataForTrainer(data_file); + List<Double> expected + = Utils.convertCSVToExpectedForTrainer(expected_file); + return run(data, expected); + + } + /** * @param data * @param expected @@ -137,6 +155,8 @@ public class Trainer { map.put(sol, out); parameters.getLogger().log(Level.INFO, "Solution fold " + i + ":" + sol + " - AUC: " + auc); + parameters.getLogger().log(Level.INFO, + "Solution fold " + i + ":" + sol + " - AUC-PR: " + prauc); } return map; @@ -162,6 +182,55 @@ public class Trainer { return runKFold(data, expected, fold_number, increase_ratio_alert); } + /** + * RunKFold with csv files as arguments. + * @param data_file_name + * @param expected_file_name + * @param fold_number + * @param increase_ration_alert + * @return + */ + public final HashMap<AbstractSolution, double[]> runKFoldCSV( + final String data_file_name, + final String expected_file_name, + final int fold_number, + final int increase_ration_alert) { + List<List<Double>> data = new ArrayList<>(); + List<Double> expected = new ArrayList<>(); + Path path_to_data = Paths.get(data_file_name); + Path path_to_expected = Paths.get(expected_file_name); + try (BufferedReader br = Files.newBufferedReader(path_to_data, + StandardCharsets.UTF_8)) { + String line = br.readLine(); + while (line != null) { + String[] data_line = line.split(","); + List<Double> elements = new ArrayList<>(); + for (int i = 0; i < data_line.length; i++) { + Double el = Double.parseDouble(data_line[i]); + elements.add(el); + } + data.add(elements); + line = br.readLine(); + } + } catch (IOException e) { + e.printStackTrace(); + } + + try (BufferedReader br = Files.newBufferedReader(path_to_expected, + StandardCharsets.UTF_8)) { + String line = br.readLine(); + while (line != null) { + Double el = Double.parseDouble(line); + expected.add(el); + line = br.readLine(); + } + } catch (IOException e) { + e.printStackTrace(); + } + return runKFold(data, expected, fold_number, increase_ration_alert); + + } + /** * @param prepared_folds * @param increase_ratio diff --git a/src/main/java/be/cylab/java/wowa/training/Utils.java b/src/main/java/be/cylab/java/wowa/training/Utils.java index 2fa2915cfe770f61f9ca53a3d5743b97ef9ab1a9..4f58ad8e699981540459a484a184538f1a6c0738 100644 --- a/src/main/java/be/cylab/java/wowa/training/Utils.java +++ b/src/main/java/be/cylab/java/wowa/training/Utils.java @@ -21,11 +21,13 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; +import java.io.BufferedReader; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStreamWriter; import java.nio.charset.StandardCharsets; import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.List; @@ -156,6 +158,35 @@ final class Utils { return data; } + /** + * Import CSV file and convert in List. + * + * @param filename + * @return + */ + static List<List<Double>> convertCSVToDataForTrainer( + final String filename) { + Path path_to_data = Paths.get(filename); + List<List<Double>> data = new ArrayList<>(); + try (BufferedReader br = Files.newBufferedReader(path_to_data, + StandardCharsets.UTF_8)) { + String line = br.readLine(); + while (line != null) { + String[] data_line = line.split(","); + List<Double> elements = new ArrayList<>(); + for (int i = 0; i < data_line.length; i++) { + Double el = Double.parseDouble(data_line[i]); + elements.add(el); + } + data.add(elements); + line = br.readLine(); + } + } catch (IOException e) { + e.printStackTrace(); + } + return data; + } + /** * Import result file. * @@ -173,6 +204,31 @@ final class Utils { return expected; } + /** + * Import expected from CSV file and convert to List. + * + * @param filename + * @return + */ + static List<Double> convertCSVToExpectedForTrainer( + final String filename) { + Path path_to_expected = Paths.get(filename); + List<Double> expected = new ArrayList<>(); + try (BufferedReader br = Files.newBufferedReader(path_to_expected, + StandardCharsets.UTF_8)) { + String line = br.readLine(); + while (line != null) { + Double el = Double.parseDouble(line); + expected.add(el); + line = br.readLine(); + } + } catch (IOException e) { + e.printStackTrace(); + } + return expected; + } + + /** * Method to read file and convert into a String. * @@ -204,12 +260,12 @@ final class Utils { return true_alert; } - /** - * Function to convert a List of doubles to an Array of doubles. - * - * @param elements - * @return - */ + /** + * Function to convert a List of doubles to an Array of doubles. + * + * @param elements + * @return + */ static double[] convertListDoubleToArrayDouble( final List<Double> elements) { Double[] w = new Double[elements.size()];