From 5a9333eeb15437532b6690eee411226d79e32254 Mon Sep 17 00:00:00 2001
From: Mahmoud Bentriou <moud@MacBook-Pro-de-Mahmoud.local>
Date: Wed, 5 Apr 2023 16:00:20 +0200
Subject: [PATCH] New feature: new method for RF ABC model choice, when the
 dataset is already simulated.

---
 algorithms/abc_model_choice.jl       | 68 ++++++++++++++++++++++------
 test/abc_model_choice/toy_example.jl |  8 +++-
 2 files changed, 60 insertions(+), 16 deletions(-)

diff --git a/algorithms/abc_model_choice.jl b/algorithms/abc_model_choice.jl
index 4fc9d9f..f346ffe 100644
--- a/algorithms/abc_model_choice.jl
+++ b/algorithms/abc_model_choice.jl
@@ -2,6 +2,7 @@
 struct AbcModelChoiceDataset
     models_indexes::Vector{Int}
     summary_stats_matrix::Matrix
+    summary_stats_observations
     epsilon::Float64
 end
 
@@ -41,6 +42,7 @@ The mandatory arguments are:
 
 The result is a `AbcModelChoiceDataset` with fields:
 * `summary_stats_matrix`: the (N_stats, N_ref) features matrix. Accessible via `.X`.
+* `summary_stats_observations`: the summary statistics of the observations used to build the dataset.
 * `models_indexes`: the labels vector. Accessible via `.y`.
 
 If specified, `dir_results` is the directory where the summary statistics matrix and associated models are stored (CSV).
@@ -113,15 +115,18 @@ function _abc_model_choice_dataset(models::Vector{<:Union{Model,ParametricModel}
         close(file_cfg)
     end
 
-    return AbcModelChoiceDataset(knn_models_indexes, knn_summary_stats_matrix, distances[k_nn[end]])
+    return AbcModelChoiceDataset(knn_models_indexes, knn_summary_stats_matrix,
+                                 summary_stats_observations, distances[k_nn[end]])
 end
 
-function _distributed_abc_model_choice_dataset(models::Vector{<:Union{Model,ParametricModel}}, models_prior::DiscreteUnivariateDistribution,
-                                              summary_stats_observations,
-                                              summary_stats_func::Function, distance_func::Function,
-                                              k::Int, N::Int; dir_results::Union{Nothing,String} = nothing)
+function _distributed_abc_model_choice_dataset(models::Vector{<:Union{Model,ParametricModel}},
+                                               models_prior::DiscreteUnivariateDistribution,
+                                               summary_stats_observations,
+                                               summary_stats_func::Function,
+                                               distance_func::Function,
+                                               k::Int, N::Int; dir_results::Union{Nothing,String} = nothing)
     @assert length(models) >= 2 "Should contain at least 2 models"
-    @assert ncategories(models_prior) == length(models) "Number of categories of models' prior and number of models do not equal"
+    @assert ncategories(models_prior) == length(models) "Number of models' prior categories and number of models do not equal"
 
     models_indexes = SharedVector{Int}(N)
     summary_stats_matrix = SharedMatrix{eltype(summary_stats_observations)}(length(summary_stats_observations), N)
@@ -157,7 +162,8 @@ function _distributed_abc_model_choice_dataset(models::Vector{<:Union{Model,Para
         close(file_cfg)
     end
 
-    return AbcModelChoiceDataset(knn_models_indexes, knn_summary_stats_matrix, distances[k_nn[end]])
+    return AbcModelChoiceDataset(knn_models_indexes, knn_summary_stats_matrix,
+                                 summary_stats_observations, distances[k_nn[end]])
 end
 
 """
@@ -175,7 +181,6 @@ The mandatory arguments are:
 * `summary_stats_func::Function`: the function that computes the summary statistics over a model simulation.
 
 The optional arguments are:
-* `abc_trainset`: an already simulated dataset with ``abc_model_choice_dataset` (by default: nothing)
 * `models_prior`: the prior over the models (by default: discrete uniform distribution)
 * `k`: the k nearest samples from the observations to keep in the reference table (by default: k = N_ref)
 * `distance_func`: the distance function, has to be defined if k < N_ref
@@ -192,30 +197,63 @@ The result is a `RandomForestABC` object with fields:
 """
 function rf_abc_model_choice(models::Vector{<:Union{Model,ParametricModel}},
                              summary_stats_observations,
-                             summary_stats_func::Function, N_ref::Int;
-                             abc_trainset::Union{Nothing,AbcModelChoiceDataset} = nothing,
+                             summary_stats_func::Function, 
+                             N_ref::Int;
                              models_prior::DiscreteUnivariateDistribution = 
                                 Categorical([1/length(models) for i = 1:length(models)]),
                              k::Int = N_ref, distance_func::Function = (x,y) -> 1, 
                              hyperparameters_range::Dict = Dict(:n_estimators => [200], :min_samples_leaf => [1],
                                                                 :min_samples_split => [2]))
     @assert k <= N_ref
-    if isnothing(abc_trainset)
-        abc_trainset = abc_model_choice_dataset(models, models_prior, summary_stats_observations,
-                                                summary_stats_func, distance_func, k, N_ref)
-    end
+    abc_trainset = abc_model_choice_dataset(models, models_prior, summary_stats_observations,
+                                            summary_stats_func, distance_func, k, N_ref)
+    gridsearch = GridSearchCV(RandomForestClassifier(oob_score=true), hyperparameters_range)
+    fit!(gridsearch, transpose(abc_trainset.X), abc_trainset.y)
+    best_rf = gridsearch.best_estimator_
+    return RandomForestABC(abc_trainset, best_rf, summary_stats_observations,
+                           predict(best_rf, [summary_stats_observations]))
+end
+
+"""
+    rf_abc_model_choice(abc_trainset::AbcModelChoiceDataset;
+                        hyperparameters_range::Dict = Dict(:n_estimators => [200],
+                            :min_samples_leaf => [1], :min_samples_split => [2]))
+
+Run the Random Forest Approximate Bayesian Computation model choice method with an already simulated dataset.
+
+The mandatory arguments are:
+* `abc_trainset`: an already simulated dataset, as returned by `abc_model_choice_dataset`
+
+The optional arguments are:
+* `hyperparameters_range`: a dict with the hyperparameters range values for the cross validation fit of the 
+    Random Forest (by default: `Dict(:n_estimators => [200], :min_samples_leaf => [1], :min_samples_split => [2])`).
+    See scikit-learn documentation of RandomForestClassifier for the hyperparameters name.
+
+The result is a `RandomForestABC` object with fields:
+* `reference_table` an AbcModelChoiceDataset that corresponds to the reference table of the algorithm, 
+* `clf` a random forest classifier (PyObject from scikit-learn),
+* `summary_stats_observations` are the summary statistics of the observations
+* `estim_model` is the underlying model of the observations inferred with the RF-ABC method.
+
+"""
+function rf_abc_model_choice(abc_trainset::MarkovProcesses.AbcModelChoiceDataset;
+                             hyperparameters_range::Dict = Dict(:n_estimators => [200], :min_samples_leaf => [1],
+                                                                :min_samples_split => [2]))
     gridsearch = GridSearchCV(RandomForestClassifier(oob_score=true), hyperparameters_range)
     fit!(gridsearch, transpose(abc_trainset.X), abc_trainset.y)
     best_rf = gridsearch.best_estimator_
-    return RandomForestABC(abc_trainset, best_rf, summary_stats_observations, predict(best_rf, [summary_stats_observations]))
+    return RandomForestABC(abc_trainset, best_rf, abc_trainset.summary_stats_observations,
+                           predict(best_rf, [abc_trainset.summary_stats_observations]))
 end
 
+
 """
     posterior_proba_model(rf_abc::RandomForestABC)
 
 Estimates the posterior probability of the model ``P(M = \\widehat{M}(s_{obs}) | s_{obs})`` with the Random Forest ABC method.
 """
 function posterior_proba_model(rf_abc::RandomForestABC)
+    @assert rf_abc.summary_stats_observations == rf_abc.reference_table.summary_stats_observations
     oob_votes = rf_abc.clf.oob_decision_function_
     y_pred_oob = argmax.([oob_votes[i,:] for i = 1:size(oob_votes)[1]])
     y_oob_regression = y_pred_oob .!= rf_abc.reference_table.y
diff --git a/test/abc_model_choice/toy_example.jl b/test/abc_model_choice/toy_example.jl
index c7f8f70..92da7d5 100644
--- a/test/abc_model_choice/toy_example.jl
+++ b/test/abc_model_choice/toy_example.jl
@@ -44,7 +44,8 @@ observations = simulate(m3)
 ss_observations = ss_func(observations)
 models = [m1, m2, m3]
 println("Testset 10000 samples")
-@timev abc_testset = abc_model_choice_dataset(models, ss_observations, ss_func, dist_l2, 10000, 10000; dir_results = "toy_ex")
+@timev abc_testset = abc_model_choice_dataset(models, ss_observations, ss_func, dist_l2, 10000, 10000;
+                                              dir_results = "toy_ex")
 
 list_lh = [lh_m1, lh_m2, lh_m3]
 prob_model(ss, list_lh, idx_model) = list_lh[idx_model](ss) / sum([list_lh[i](ss) for i = eachindex(list_lh)])
@@ -72,11 +73,16 @@ savefig("set.svg")
 
 grid = Dict(:n_estimators => [500], :min_samples_leaf => [1], :min_samples_split => [2], :n_jobs => [8])
 println("RF ABC")
+# When rf_abc_model_choice simulates the abc dataset
 @timev res_rf_abc = rf_abc_model_choice(models, ss_observations, ss_func, 29000; hyperparameters_range = grid)
 @show posterior_proba_model(res_rf_abc)
 X_testset = transpose(abc_testset.X)
 println(classification_report(y_true = abc_testset.y, y_pred = predict(res_rf_abc.clf, X_testset)))
 @show accuracy_score(abc_testset.y, predict(res_rf_abc.clf, X_testset))
+# When rf_abc_model_choice uses an already simulated dataset
+@timev abc_dataset = abc_model_choice_dataset(models, ss_observations, ss_func, dist_l2, 29000, 29000)
+@timev res_rf_abc = rf_abc_model_choice(abc_dataset; hyperparameters_range = grid)
+@show posterior_proba_model(res_rf_abc)
 
 return true
 
-- 
GitLab