Skip to content
Snippets Groups Projects
Commit 7d3b7089 authored by Bentriou Mahmoud's avatar Bentriou Mahmoud
Browse files

small changes

parent 97df1230
No related branches found
No related tags found
No related merge requests found
......@@ -64,8 +64,8 @@ function rf_abc_model_choice(models::Vector{<:Union{Model,ParametricModel}},
summary_stats_observations,
summary_stats_func::Function, N_ref::Int;
k::Int = N_ref, distance_func::Function = (x,y) -> 1,
hyperparameters_range::Dict = Dict(:n_estimators => [200], :min_samples_leaf => [1, 2],
:min_samples_split => [2,5]))
hyperparameters_range::Dict = Dict(:n_estimators => [200], :min_samples_leaf => [1],
:min_samples_split => [2]))
@assert k <= N_ref
trainset = abc_model_choice_dataset(models, summary_stats_observations, summary_stats_func, distance_func, k, N_ref)
gridsearch = GridSearchCV(RandomForestClassifier(oob_score=true), hyperparameters_range)
......@@ -80,7 +80,7 @@ function posterior_proba_model(rf_abc::RandomForestABC)
y_pred_oob = argmax.([oob_votes[i,:] for i = 1:size(oob_votes)[1]])
y_oob_regression = y_pred_oob .!= rf_abc.reference_table.y
dict_params = Dict()
for param in ["n_estimators", "min_samples_leaf", "min_samples_split", "oob_score"]
for param in ["n_estimators", "min_samples_leaf", "min_samples_split", "oob_score", "n_jobs"]
dict_params[Symbol(param)] = get_params(rf_abc.clf)[param]
end
rf_regressor = RandomForestRegressor(;dict_params...)
......
......@@ -33,7 +33,7 @@ lh_m2(s) = exp(-s[2]^2/(2n*(n+1)) - (s[3]^2)/2 + (s[2]^2)/(2n) - s[2]) * (2pi)^(
lh_m3(s) = exp(s[2])*gamma(2n+1)/gamma(2)^n * (1+s[1])^(-2n-1)
ss_func(y) = [sum(y), sum(log.(y)), sum(log.(y).^2)]
dist_l2(s_sim,s_obs) = sqrt(dot(s_sim,s_obs))
dist_l2(s_sim,s_obs) = norm(s_sim-s_obs)
observations = simulate(m3)
ss_observations = ss_func(observations)
......@@ -64,8 +64,8 @@ end
savefig("set.svg")
=#
grid = Dict(:n_estimators => [500], :min_samples_leaf => [1], :min_samples_split => [2])
res_rf_abc = rf_abc_model_choice(models, ss_observations, ss_func, 29000; hyperparameters_range = grid)
grid = Dict(:n_estimators => [500], :min_samples_leaf => [1], :min_samples_split => [2], :n_jobs => [8])
@timev res_rf_abc = rf_abc_model_choice(models, ss_observations, ss_func, 29000; hyperparameters_range = grid)
@show posterior_proba_model(res_rf_abc)
println(classification_report(y_true = abc_testset.y, y_pred = predict(res_rf_abc.clf, abc_testset.X)))
@show accuracy_score(abc_testset.y, predict(res_rf_abc.clf, abc_testset.X))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment