From 7d3b708955e3b030f81aa761ba50f9b2313c129a Mon Sep 17 00:00:00 2001
From: Mahmoud Bentriou <mahmoud.bentriou@centralesupelec.fr>
Date: Sun, 16 May 2021 08:14:12 +0200
Subject: [PATCH] Narrow default RF grid, forward n_jobs, fix toy-example L2 distance

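Narrow the default GridSearchCV grid in rf_abc_model_choice to a single
candidate (n_estimators = 200, min_samples_leaf = 1, min_samples_split = 2),
so a default call no longer cross-validates over four parameter combinations.
The previous, wider search can still be requested explicitly through the
hyperparameters_range keyword, e.g. (illustrative call, names taken from the
toy example; not part of this patch):

    grid = Dict(:n_estimators => [200], :min_samples_leaf => [1, 2],
                :min_samples_split => [2, 5])
    rf_abc_model_choice(models, ss_observations, ss_func, 29000;
                        hyperparameters_range = grid)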
---
 algorithms/abc_model_choice.jl        | 6 +++---
 tests/abc_model_choice/toy_example.jl | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

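Note on the distance fix in the toy example: the old
dist_l2(s_sim, s_obs) = sqrt(dot(s_sim, s_obs)) is the square root of the
inner product of the two summary-statistic vectors, not a distance (it is
nonzero even when the vectors coincide), whereas norm(s_sim - s_obs) is the
Euclidean (L2) distance. A minimal sketch of the difference, assuming
LinearAlgebra is loaded (both dot and norm require it):

    using LinearAlgebra

    s_sim = [1.0, 2.0, 3.0]
    s_obs = [1.0, 2.0, 3.0]

    sqrt(dot(s_sim, s_obs))   # old expression: sqrt(14) ≈ 3.742 even for identical vectors
    norm(s_sim - s_obs)       # new expression: 0.0, as expected for identical vectors
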
diff --git a/algorithms/abc_model_choice.jl b/algorithms/abc_model_choice.jl
index 4ccdc99..44ee0ce 100644
--- a/algorithms/abc_model_choice.jl
+++ b/algorithms/abc_model_choice.jl
@@ -64,8 +64,8 @@ function rf_abc_model_choice(models::Vector{<:Union{Model,ParametricModel}},
                              summary_stats_observations,
                              summary_stats_func::Function, N_ref::Int;
                              k::Int = N_ref, distance_func::Function = (x,y) -> 1, 
-                             hyperparameters_range::Dict = Dict(:n_estimators => [200], :min_samples_leaf => [1, 2],
-                                                                :min_samples_split => [2,5]))
+                             hyperparameters_range::Dict = Dict(:n_estimators => [200], :min_samples_leaf => [1],
+                                                                :min_samples_split => [2]))
     @assert k <= N_ref
     trainset = abc_model_choice_dataset(models, summary_stats_observations, summary_stats_func, distance_func, k, N_ref)
     gridsearch = GridSearchCV(RandomForestClassifier(oob_score=true), hyperparameters_range)
@@ -80,7 +80,7 @@ function posterior_proba_model(rf_abc::RandomForestABC)
     y_pred_oob = argmax.([oob_votes[i,:] for i = 1:size(oob_votes)[1]])
     y_oob_regression = y_pred_oob .!= rf_abc.reference_table.y
     dict_params = Dict()
-    for param in ["n_estimators", "min_samples_leaf", "min_samples_split", "oob_score"]
+    for param in ["n_estimators", "min_samples_leaf", "min_samples_split", "oob_score", "n_jobs"]
         dict_params[Symbol(param)] = get_params(rf_abc.clf)[param]
     end
     rf_regressor = RandomForestRegressor(;dict_params...)
diff --git a/tests/abc_model_choice/toy_example.jl b/tests/abc_model_choice/toy_example.jl
index 6346711..3cdef11 100644
--- a/tests/abc_model_choice/toy_example.jl
+++ b/tests/abc_model_choice/toy_example.jl
@@ -33,7 +33,7 @@ lh_m2(s) = exp(-s[2]^2/(2n*(n+1)) - (s[3]^2)/2 + (s[2]^2)/(2n) - s[2]) * (2pi)^(
 lh_m3(s) = exp(s[2])*gamma(2n+1)/gamma(2)^n * (1+s[1])^(-2n-1)
 
 ss_func(y) = [sum(y), sum(log.(y)), sum(log.(y).^2)]
-dist_l2(s_sim,s_obs) = sqrt(dot(s_sim,s_obs))
+dist_l2(s_sim,s_obs) = norm(s_sim-s_obs)
 
 observations = simulate(m3)
 ss_observations = ss_func(observations)
@@ -64,8 +64,8 @@ end
 savefig("set.svg")
 =#
 
-grid = Dict(:n_estimators => [500], :min_samples_leaf => [1], :min_samples_split => [2])
-res_rf_abc = rf_abc_model_choice(models, ss_observations, ss_func, 29000; hyperparameters_range = grid)
+grid = Dict(:n_estimators => [500], :min_samples_leaf => [1], :min_samples_split => [2], :n_jobs => [8])
+@timev res_rf_abc = rf_abc_model_choice(models, ss_observations, ss_func, 29000; hyperparameters_range = grid)
 @show posterior_proba_model(res_rf_abc) 
 println(classification_report(y_true = abc_testset.y, y_pred = predict(res_rf_abc.clf, abc_testset.X)))
 @show accuracy_score(abc_testset.y, predict(res_rf_abc.clf, abc_testset.X))
-- 
GitLab
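
Note on the n_jobs addition: posterior_proba_model rebuilds a
RandomForestRegressor from the classifier's parameters to regress the
out-of-bag classification error, so any parameter not copied over falls back
to the scikit-learn default. Copying n_jobs as well presumably lets that
regressor reuse the classifier's parallel settings (e.g. the :n_jobs => [8]
entry added to the toy-example grid). A minimal sketch of the forwarding
pattern, assuming scikit-learn is available through ScikitLearn.jl as in the
original code (clf is illustrative, not the fitted model from the test):

    using ScikitLearn
    @sk_import ensemble: (RandomForestClassifier, RandomForestRegressor)

    clf = RandomForestClassifier(n_estimators = 500, oob_score = true, n_jobs = 8)

    dict_params = Dict{Symbol,Any}()
    for param in ["n_estimators", "min_samples_leaf", "min_samples_split", "oob_score", "n_jobs"]
        dict_params[Symbol(param)] = get_params(clf)[param]   # copy the classifier's settings
    end
    rf_regressor = RandomForestRegressor(; dict_params...)    # regressor now also uses n_jobs = 8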