From 6f92c1d39b72296e180f80ccf6541b459a74803b Mon Sep 17 00:00:00 2001 From: Alex <croix.alexandre@gmail.com> Date: Thu, 5 Sep 2019 14:46:42 +0200 Subject: [PATCH] Add setters in HyperParemeters class --- .../java/wowa/training/HyperParameters.java | 59 +++++++++++++++++-- 1 file changed, 53 insertions(+), 6 deletions(-) diff --git a/src/main/java/be/cylab/java/wowa/training/HyperParameters.java b/src/main/java/be/cylab/java/wowa/training/HyperParameters.java index 32a9084..a5df93f 100644 --- a/src/main/java/be/cylab/java/wowa/training/HyperParameters.java +++ b/src/main/java/be/cylab/java/wowa/training/HyperParameters.java @@ -27,11 +27,11 @@ public class HyperParameters { final OptimizationAlgorithm algorithm, final Activation activation_function, final double percent_test_train) { - this.neurons_number = neurons_number; - this.learning_rate = learning_rate; - this.algorithm = algorithm; - this.activation_function = activation_function; - this.percent_test_train = percent_test_train; + setNeuronsNumber(neurons_number); + setLearningRate(learning_rate); + setAlgorithm(algorithm); + setActivationFunction(activation_function); + setPercentTestTrain(percent_test_train); } /** @@ -47,7 +47,7 @@ public class HyperParameters { final OptimizationAlgorithm algorithm, final Activation activation_function) { this(neurons_number, learning_rate, algorithm, - activation_function, 100); + activation_function, 99); } /** @@ -89,4 +89,51 @@ public class HyperParameters { public final double getPercentTestTrain() { return percent_test_train; } + + /** + * @param neurons_number + */ + public void setNeuronsNumber(final int neurons_number) { + if (neurons_number < 5 || neurons_number > 80) { + throw new IllegalArgumentException( + "Neuron number must be between 5 and 80"); + } + this.neurons_number = neurons_number; + } + + /** + * @param learning_rate + */ + public void setLearningRate(final double learning_rate) { + if (learning_rate <= 0.0 || learning_rate >= 1.0) { + throw new IllegalArgumentException( + "Learning rate must be between 0 and 1"); + } + this.learning_rate = learning_rate; + } + + /** + * @param algorithm + */ + public void setAlgorithm(final OptimizationAlgorithm algorithm) { + this.algorithm = algorithm; + } + + /** + * @param activation_function + */ + public void setActivationFunction(final Activation activation_function) { + this.activation_function = activation_function; + } + + /** + * @param percent_test_train + */ + public void setPercentTestTrain(final double percent_test_train) { + if (percent_test_train <= 0 || percent_test_train >= 100) { + throw new IllegalArgumentException( + "Percentage of train must be between 0 and 100"); + } + this.percent_test_train = percent_test_train; + } } -- GitLab