From 098f8c9e3d6adfc2a85adfc9550f670d38da7beb Mon Sep 17 00:00:00 2001
From: Mahmoud Bentriou <moud@MacBook-Pro-de-Mahmoud.local>
Date: Wed, 5 Apr 2023 10:38:33 +0200
Subject: [PATCH] Add new feature to rf abc: presimulated dataset

---
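A minimal usage sketch of the new `abc_trainset` option (illustrative only:
`models`, `summary_stats_observations`, `summary_stats_func` and `N_ref` are
placeholders assumed to be defined as in the existing docstring, and
`abc_model_choice_dataset` is called with the positional signature visible in
the patch below):

    # Build the reference table once, with the default uniform prior over models
    # and the default trivial distance (k = N_ref keeps the whole table).
    models_prior = Categorical([1/length(models) for i = 1:length(models)])
    trainset = abc_model_choice_dataset(models, models_prior,
                                        summary_stats_observations,
                                        summary_stats_func,
                                        (x,y) -> 1, N_ref, N_ref)
    # Reuse the presimulated dataset: no new model simulations are run here.
    res = rf_abc_model_choice(models, summary_stats_observations,
                              summary_stats_func, N_ref;
                              abc_trainset = trainset)

This lets several random-forest fits (e.g. different `hyperparameters_range`
grids) share the same simulated reference table instead of re-simulating it
on every call.
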
 algorithms/abc_model_choice.jl | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/algorithms/abc_model_choice.jl b/algorithms/abc_model_choice.jl
index 5760591..4fc9d9f 100644
--- a/algorithms/abc_model_choice.jl
+++ b/algorithms/abc_model_choice.jl
@@ -175,11 +175,12 @@ The mandatory arguments are:
 * `summary_stats_func::Function`: the function that computes the summary statistics over a model simulation.
 
 The optional arguments are:
+* `abc_trainset`: a dataset already simulated with `abc_model_choice_dataset` (by default: nothing)
 * `models_prior`: the prior over the models (by default: discrete uniform distribution)
 * `k`: the k nearest samples from the observations to keep in the reference table (by default: k = N_ref)
 * `distance_func`: the distance function, has to be defined if k < N_ref
-* `hyperparameters_range`: a dict with the hyperparameters range values for the cross validation
-    fit of the Random Forest (by default: `Dict(:n_estimators => [200], :min_samples_leaf => [1], :min_samples_split => [2])`).
+* `hyperparameters_range`: a dict with the hyperparameters range values for the cross validation fit of
+    the Random Forest (by default: `Dict(:n_estimators => [200], :min_samples_leaf => [1], :min_samples_split => [2])`).
     See scikit-learn documentation of RandomForestClassifier for the hyperparameters name.
 
 The result is a `RandomForestABC` object with fields:
@@ -192,16 +193,21 @@ The result is a `RandomForestABC` object with fields:
 function rf_abc_model_choice(models::Vector{<:Union{Model,ParametricModel}},
                              summary_stats_observations,
                              summary_stats_func::Function, N_ref::Int;
-                             models_prior::DiscreteUnivariateDistribution = Categorical([1/length(models) for i = 1:length(models)]),
+                             abc_trainset::Union{Nothing,AbcModelChoiceDataset} = nothing,
+                             models_prior::DiscreteUnivariateDistribution = 
+                                Categorical([1/length(models) for i = 1:length(models)]),
                              k::Int = N_ref, distance_func::Function = (x,y) -> 1, 
                              hyperparameters_range::Dict = Dict(:n_estimators => [200], :min_samples_leaf => [1],
                                                                 :min_samples_split => [2]))
     @assert k <= N_ref
-    trainset = abc_model_choice_dataset(models, models_prior, summary_stats_observations, summary_stats_func, distance_func, k, N_ref)
+    if isnothing(abc_trainset)
+        abc_trainset = abc_model_choice_dataset(models, models_prior, summary_stats_observations,
+                                                summary_stats_func, distance_func, k, N_ref)
+    end
     gridsearch = GridSearchCV(RandomForestClassifier(oob_score=true), hyperparameters_range)
-    fit!(gridsearch, transpose(trainset.X), trainset.y)
+    fit!(gridsearch, transpose(abc_trainset.X), abc_trainset.y)
     best_rf = gridsearch.best_estimator_
-    return RandomForestABC(trainset, best_rf, summary_stats_observations, predict(best_rf, [summary_stats_observations]))
+    return RandomForestABC(abc_trainset, best_rf, summary_stats_observations, predict(best_rf, [summary_stats_observations]))
 end
 
 """
-- 
GitLab