From 8151024c05d2252a33678e43b26aae9a347d7f97 Mon Sep 17 00:00:00 2001
From: Mahmoud Bentriou <moud@MacBook-Pro-de-Mahmoud.local>
Date: Wed, 5 Apr 2023 10:52:49 +0200
Subject: [PATCH] Add new feature to rf abc: presimulated dataset

---
 algorithms/abc_model_choice.jl | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/algorithms/abc_model_choice.jl b/algorithms/abc_model_choice.jl
index 5760591..4fc9d9f 100644
--- a/algorithms/abc_model_choice.jl
+++ b/algorithms/abc_model_choice.jl
@@ -175,11 +175,12 @@ The mandatory arguments are:
 * `summary_stats_func::Function`: the function that computes the summary statistics over a model simulation.
 
 The optional arguments are:
+* `abc_trainset`: a dataset already simulated with `abc_model_choice_dataset` (by default: nothing)
 * `models_prior`: the prior over the models (by default: discrete uniform distribution)
 * `k`: the k nearest samples from the observations to keep in the reference table (by default: k = N_ref)
 * `distance_func`: the distance function, has to be defined if k < N_ref
-* `hyperparameters_range`: a dict with the hyperparameters range values for the cross validation
-  fit of the Random Forest (by default: `Dict(:n_estimators => [200], :min_samples_leaf => [1], :min_samples_split => [2])`).
+* `hyperparameters_range`: a dict with the hyperparameters range values for the cross validation fit of the
+  Random Forest (by default: `Dict(:n_estimators => [200], :min_samples_leaf => [1], :min_samples_split => [2])`).
   See scikit-learn documentation of RandomForestClassifier for the hyperparameters name.
 
 The result is a `RandomForestABC` object with fields:
@@ -192,16 +193,21 @@ The result is a `RandomForestABC` object with fields:
 function rf_abc_model_choice(models::Vector{<:Union{Model,ParametricModel}},
                              summary_stats_observations,
                              summary_stats_func::Function, N_ref::Int;
-                             models_prior::DiscreteUnivariateDistribution = Categorical([1/length(models) for i = 1:length(models)]),
+                             abc_trainset::Union{Nothing,AbcModelChoiceDataset} = nothing,
+                             models_prior::DiscreteUnivariateDistribution =
+                                 Categorical([1/length(models) for i = 1:length(models)]),
                              k::Int = N_ref, distance_func::Function = (x,y) -> 1,
                              hyperparameters_range::Dict = Dict(:n_estimators => [200], :min_samples_leaf => [1],
                                                                 :min_samples_split => [2]))
     @assert k <= N_ref
-    trainset = abc_model_choice_dataset(models, models_prior, summary_stats_observations, summary_stats_func, distance_func, k, N_ref)
+    if isnothing(abc_trainset)
+        abc_trainset = abc_model_choice_dataset(models, models_prior, summary_stats_observations,
+                                                summary_stats_func, distance_func, k, N_ref)
+    end
     gridsearch = GridSearchCV(RandomForestClassifier(oob_score=true), hyperparameters_range)
-    fit!(gridsearch, transpose(trainset.X), trainset.y)
+    fit!(gridsearch, transpose(abc_trainset.X), abc_trainset.y)
     best_rf = gridsearch.best_estimator_
-    return RandomForestABC(trainset, best_rf, summary_stats_observations, predict(best_rf, [summary_stats_observations]))
+    return RandomForestABC(abc_trainset, best_rf, summary_stats_observations, predict(best_rf, [summary_stats_observations]))
 end
 
 """
--
GitLab
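
A possible usage pattern for the new `abc_trainset` keyword, as a minimal sketch: `m1`, `m2`, `summary_stats_observations` and `summary_stats_func` are placeholders assumed to be defined by the caller, and the positional argument order of `abc_model_choice_dataset` is inferred from the call inside this patch.

    # Hypothetical setup: two candidate models and user-defined summary statistics.
    models = [m1, m2]
    models_prior = Categorical([0.5, 0.5])
    N_ref = 1000

    # Simulate the reference table once; this is usually the expensive step.
    # Argument order assumed from the internal call in rf_abc_model_choice.
    trainset = abc_model_choice_dataset(models, models_prior,
                                        summary_stats_observations, summary_stats_func,
                                        (x, y) -> 1, N_ref, N_ref)

    # Reuse the presimulated dataset across several Random Forest fits,
    # e.g. to try different hyperparameter grids without re-simulating.
    res_default = rf_abc_model_choice(models, summary_stats_observations,
                                      summary_stats_func, N_ref;
                                      abc_trainset = trainset)
    res_tuned = rf_abc_model_choice(models, summary_stats_observations,
                                    summary_stats_func, N_ref;
                                    abc_trainset = trainset,
                                    hyperparameters_range = Dict(:n_estimators => [100, 500],
                                                                 :min_samples_leaf => [1, 5],
                                                                 :min_samples_split => [2, 10]))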