Commit 5a9333ee authored by Mahmoud Bentriou

New feature: new method for RF ABC model choice, when the dataset is already simulated.

parent 8151024c
Pipeline #30280 failed
@@ -2,6 +2,7 @@
struct AbcModelChoiceDataset
models_indexes::Vector{Int}
summary_stats_matrix::Matrix
summary_stats_observations
epsilon::Float64
end
@@ -41,6 +42,7 @@ The mandatory arguments are:
The result is a `AbcModelChoiceDataset` with fields:
* `summary_stats_matrix`: the (N_stats, N_ref) features matrix. Accessible via `.X`.
* `summary_stats_observations`: the summary statistics of the observations used to build the dataset.
* `models_indexes`: the labels vector. Accessible via `.y`.
If specified, `dir_results` is the directory where the summary statistics matrix and associated models are stored (CSV).
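# Usage sketch, assuming `models`, `ss_observations`, `ss_func` and `dist_l2` are
# defined as in the toy example further down: build a reference table and inspect
# the new `summary_stats_observations` field alongside `.X` and `.y`.
abc_dataset = abc_model_choice_dataset(models, ss_observations, ss_func, dist_l2, 1000, 1000)
size(abc_dataset.X)                      # (N_stats, N_ref) summary statistics matrix
length(abc_dataset.y)                    # model indexes (labels) of the simulated samples
abc_dataset.summary_stats_observations   # observed summary statistics stored with the dataset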
@@ -113,15 +115,18 @@ function _abc_model_choice_dataset(models::Vector{<:Union{Model,ParametricModel}
close(file_cfg)
end
return AbcModelChoiceDataset(knn_models_indexes, knn_summary_stats_matrix, distances[k_nn[end]])
return AbcModelChoiceDataset(knn_models_indexes, knn_summary_stats_matrix,
summary_stats_observations, distances[k_nn[end]])
end
function _distributed_abc_model_choice_dataset(models::Vector{<:Union{Model,ParametricModel}}, models_prior::DiscreteUnivariateDistribution,
summary_stats_observations,
summary_stats_func::Function, distance_func::Function,
k::Int, N::Int; dir_results::Union{Nothing,String} = nothing)
function _distributed_abc_model_choice_dataset(models::Vector{<:Union{Model,ParametricModel}},
models_prior::DiscreteUnivariateDistribution,
summary_stats_observations,
summary_stats_func::Function,
distance_func::Function,
k::Int, N::Int; dir_results::Union{Nothing,String} = nothing)
@assert length(models) >= 2 "Should contain at least 2 models"
@assert ncategories(models_prior) == length(models) "Number of categories of models' prior and number of models do not equal"
@assert ncategories(models_prior) == length(models) "Number of categories of the models' prior does not equal the number of models"
models_indexes = SharedVector{Int}(N)
summary_stats_matrix = SharedMatrix{eltype(summary_stats_observations)}(length(summary_stats_observations), N)
@@ -157,7 +162,8 @@ function _distributed_abc_model_choice_dataset(models::Vector{<:Union{Model,Para
close(file_cfg)
end
return AbcModelChoiceDataset(knn_models_indexes, knn_summary_stats_matrix, distances[k_nn[end]])
return AbcModelChoiceDataset(knn_models_indexes, knn_summary_stats_matrix,
summary_stats_observations, distances[k_nn[end]])
end
"""
@@ -175,7 +181,6 @@ The mandatory arguments are:
* `summary_stats_func::Function`: the function that computes the summary statistics over a model simulation.
The optional arguments are:
* `abc_trainset`: an already simulated dataset with ``abc_model_choice_dataset` (by default: nothing)
* `models_prior`: the prior over the models (by default: discrete uniform distribution)
* `k`: the k nearest samples from the observations to keep in the reference table (by default: k = N_ref)
* `distance_func`: the distance function; it has to be defined if k < N_ref
@@ -192,30 +197,63 @@ The result is a `RandomForestABC` object with fields:
"""
function rf_abc_model_choice(models::Vector{<:Union{Model,ParametricModel}},
summary_stats_observations,
summary_stats_func::Function, N_ref::Int;
abc_trainset::Union{Nothing,AbcModelChoiceDataset} = nothing,
summary_stats_func::Function,
N_ref::Int;
models_prior::DiscreteUnivariateDistribution =
Categorical([1/length(models) for i = 1:length(models)]),
k::Int = N_ref, distance_func::Function = (x,y) -> 1,
hyperparameters_range::Dict = Dict(:n_estimators => [200], :min_samples_leaf => [1],
:min_samples_split => [2]))
@assert k <= N_ref
if isnothing(abc_trainset)
abc_trainset = abc_model_choice_dataset(models, models_prior, summary_stats_observations,
summary_stats_func, distance_func, k, N_ref)
end
abc_trainset = abc_model_choice_dataset(models, models_prior, summary_stats_observations,
summary_stats_func, distance_func, k, N_ref)
gridsearch = GridSearchCV(RandomForestClassifier(oob_score=true), hyperparameters_range)
fit!(gridsearch, transpose(abc_trainset.X), abc_trainset.y)
best_rf = gridsearch.best_estimator_
return RandomForestABC(abc_trainset, best_rf, summary_stats_observations,
predict(best_rf, [summary_stats_observations]))
end
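# Usage sketch of this entry point, which simulates the reference table itself;
# `models`, `ss_observations` and `ss_func` are assumed to be defined as in the
# toy example further down.
res_rf = rf_abc_model_choice(models, ss_observations, ss_func, 1000;
                             hyperparameters_range = Dict(:n_estimators => [200]))
res_rf.estim_model   # model index inferred for the observations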
"""
rf_abc_model_choice(abc_trainset::AbcModelChoiceDataset;
                    hyperparameters_range::Dict)
Run the Random Forest Approximate Bayesian Computation model choice method with an already simulated dataset.
The mandatory arguments are:
* `abc_trainset`: a dataset already simulated with `abc_model_choice_dataset`
The optional arguments are:
* `hyperparameters_range`: a dict with the hyperparameter ranges for the cross-validation fit of the
Random Forest (by default: `Dict(:n_estimators => [200], :min_samples_leaf => [1], :min_samples_split => [2])`).
See the scikit-learn documentation of RandomForestClassifier for the hyperparameter names.
The result is a `RandomForestABC` object with fields:
* `reference_table`: an `AbcModelChoiceDataset` that corresponds to the reference table of the algorithm,
* `clf`: a random forest classifier (PyObject from scikit-learn),
* `summary_stats_observations`: the summary statistics of the observations,
* `estim_model`: the model of the observations inferred with the RF-ABC method.
"""
function rf_abc_model_choice(abc_trainset::MarkovProcesses.AbcModelChoiceDataset;
hyperparameters_range::Dict = Dict(:n_estimators => [200], :min_samples_leaf => [1],
:min_samples_split => [2]))
gridsearch = GridSearchCV(RandomForestClassifier(oob_score=true), hyperparameters_range)
fit!(gridsearch, transpose(abc_trainset.X), abc_trainset.y)
best_rf = gridsearch.best_estimator_
return RandomForestABC(abc_trainset, best_rf, summary_stats_observations, predict(best_rf, [summary_stats_observations]))
return RandomForestABC(abc_trainset, best_rf, abc_trainset.summary_stats_observations,
predict(best_rf, [abc_trainset.summary_stats_observations]))
end
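# Usage sketch with an already simulated dataset, assuming `abc_dataset` was built
# with `abc_model_choice_dataset` as above; the observations are taken from the
# dataset itself, and `posterior_proba_model` then estimates the posterior
# probability of the selected model.
res_rf_trainset = rf_abc_model_choice(abc_dataset;
                                      hyperparameters_range = Dict(:n_estimators => [500]))
res_rf_trainset.estim_model            # model index inferred for the observations
posterior_proba_model(res_rf_trainset) # estimate of P(M = estim_model | s_obs)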
"""
posterior_proba_model(rf_abc::RandomForestABC)
Estimates the posterior probability of the model ``P(M = \\widehat{M}(s_{obs}) | s_{obs})`` with the Random Forest ABC method.
"""
function posterior_proba_model(rf_abc::RandomForestABC)
@assert rf_abc.summary_stats_observations == rf_abc.reference_table.summary_stats_observations
oob_votes = rf_abc.clf.oob_decision_function_
y_pred_oob = argmax.([oob_votes[i,:] for i = 1:size(oob_votes)[1]])
y_oob_regression = y_pred_oob .!= rf_abc.reference_table.y
@@ -44,7 +44,8 @@ observations = simulate(m3)
ss_observations = ss_func(observations)
models = [m1, m2, m3]
println("Testset 10000 samples")
@timev abc_testset = abc_model_choice_dataset(models, ss_observations, ss_func, dist_l2, 10000, 10000; dir_results = "toy_ex")
@timev abc_testset = abc_model_choice_dataset(models, ss_observations, ss_func, dist_l2, 10000, 10000;
dir_results = "toy_ex")
list_lh = [lh_m1, lh_m2, lh_m3]
prob_model(ss, list_lh, idx_model) = list_lh[idx_model](ss) / sum([list_lh[i](ss) for i = eachindex(list_lh)])
@@ -72,11 +73,16 @@ savefig("set.svg")
grid = Dict(:n_estimators => [500], :min_samples_leaf => [1], :min_samples_split => [2], :n_jobs => [8])
println("RF ABC")
# When rf_abc_model_choice simulates the abc dataset
@timev res_rf_abc = rf_abc_model_choice(models, ss_observations, ss_func, 29000; hyperparameters_range = grid)
@show posterior_proba_model(res_rf_abc)
X_testset = transpose(abc_testset.X)
println(classification_report(y_true = abc_testset.y, y_pred = predict(res_rf_abc.clf, X_testset)))
@show accuracy_score(abc_testset.y, predict(res_rf_abc.clf, X_testset))
# When rf_abc_model_choice uses an already simulated dataset
@timev abc_dataset = abc_model_choice_dataset(models, ss_observations, ss_func, dist_l2, 29000, 29000)
@timev res_rf_abc = rf_abc_model_choice(abc_dataset; hyperparameters_range = grid)
@show posterior_proba_model(res_rf_abc)
return true