Commit a23cea60 authored by Mahmoud Bentriou, committed by Mahmoud Bentriou

Add new feature to rf abc: presimulated dataset

parent a156ee74
@@ -175,11 +175,12 @@ The mandatory arguments are:
 * `summary_stats_func::Function`: the function that computes the summary statistics over a model simulation.
 The optional arguments are:
+* `abc_trainset`: a dataset already simulated with `abc_model_choice_dataset` (by default: nothing)
 * `models_prior`: the prior over the models (by default: discrete uniform distribution)
 * `k`: the k nearest samples from the observations to keep in the reference table (by default: k = N_ref)
 * `distance_func`: the distance function, has to be defined if k < N_ref
-* `hyperparameters_range`: a dict with the hyperparameters range values for the cross validation
-fit of the Random Forest (by default: `Dict(:n_estimators => [200], :min_samples_leaf => [1], :min_samples_split => [2])`).
+* `hyperparameters_range`: a dict with the hyperparameters range values for the cross validation fit of the
+Random Forest (by default: `Dict(:n_estimators => [200], :min_samples_leaf => [1], :min_samples_split => [2])`).
 See scikit-learn documentation of RandomForestClassifier for the hyperparameters name.
 The result is a `RandomForestABC` object with fields:
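For illustration, here is a minimal sketch of how the optional keyword arguments documented above might be supplied. The `models` vector, `summary_stats_observations`, `my_summary_stats`, and the grid values are hypothetical placeholders; `Categorical` comes from Distributions.jl.

```julia
# Minimal sketch, assuming `models`, `summary_stats_observations` and
# `my_summary_stats` are already defined as in the mandatory arguments.
using Distributions  # for Categorical

models_prior = Categorical([0.5, 0.3, 0.2])   # non-uniform prior over three models
hyperparameters_range = Dict(:n_estimators => [200, 500],
                             :min_samples_leaf => [1, 5],
                             :min_samples_split => [2, 10])

res = rf_abc_model_choice(models, summary_stats_observations, my_summary_stats, 1000;
                          models_prior = models_prior,
                          hyperparameters_range = hyperparameters_range)
```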
@@ -192,16 +193,21 @@ The result is a `RandomForestABC` object with fields:
 function rf_abc_model_choice(models::Vector{<:Union{Model,ParametricModel}},
         summary_stats_observations,
         summary_stats_func::Function, N_ref::Int;
-        models_prior::DiscreteUnivariateDistribution = Categorical([1/length(models) for i = 1:length(models)]),
+        abc_trainset::Union{Nothing,AbcModelChoiceDataset} = nothing,
+        models_prior::DiscreteUnivariateDistribution =
+            Categorical([1/length(models) for i = 1:length(models)]),
         k::Int = N_ref, distance_func::Function = (x,y) -> 1,
         hyperparameters_range::Dict = Dict(:n_estimators => [200], :min_samples_leaf => [1],
                                            :min_samples_split => [2]))
     @assert k <= N_ref
-    trainset = abc_model_choice_dataset(models, models_prior, summary_stats_observations, summary_stats_func, distance_func, k, N_ref)
+    if isnothing(abc_trainset)
+        abc_trainset = abc_model_choice_dataset(models, models_prior, summary_stats_observations,
+                                                summary_stats_func, distance_func, k, N_ref)
+    end
     gridsearch = GridSearchCV(RandomForestClassifier(oob_score=true), hyperparameters_range)
-    fit!(gridsearch, transpose(trainset.X), trainset.y)
+    fit!(gridsearch, transpose(abc_trainset.X), abc_trainset.y)
     best_rf = gridsearch.best_estimator_
-    return RandomForestABC(trainset, best_rf, summary_stats_observations, predict(best_rf, [summary_stats_observations]))
+    return RandomForestABC(abc_trainset, best_rf, summary_stats_observations, predict(best_rf, [summary_stats_observations]))
 end
 """
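As a usage sketch of the new `abc_trainset` option, the reference table can be simulated once with `abc_model_choice_dataset` (using the positional order seen in the function body above) and then reused across several random-forest fits, so the model simulations are not repeated. The `models` vector, `summary_stats_observations` and `my_summary_stats` below are hypothetical placeholders.

```julia
# Sketch only: presimulate the reference table once, then reuse it.
# The positional order follows the call inside rf_abc_model_choice above;
# `models`, `summary_stats_observations` and `my_summary_stats` are hypothetical.
using Distributions  # for Categorical

N_ref = 2000
models_prior = Categorical(fill(1/length(models), length(models)))
trainset = abc_model_choice_dataset(models, models_prior, summary_stats_observations,
                                    my_summary_stats, (x, y) -> 1, N_ref, N_ref)

# The same presimulated dataset feeds two fits with different hyperparameter grids.
res_default = rf_abc_model_choice(models, summary_stats_observations, my_summary_stats, N_ref;
                                  abc_trainset = trainset)
res_tuned = rf_abc_model_choice(models, summary_stats_observations, my_summary_stats, N_ref;
                                abc_trainset = trainset,
                                hyperparameters_range = Dict(:n_estimators => [100, 500],
                                                             :min_samples_leaf => [1, 3]))
```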