diff --git a/core/MarkovProcesses.jl b/core/MarkovProcesses.jl index a12a622e6d1189396843acfa9827f284f1c60df4..36b234fd3ff565fd87ccb85be62813ad12beb709 100644 --- a/core/MarkovProcesses.jl +++ b/core/MarkovProcesses.jl @@ -28,7 +28,7 @@ export get_index, get_value, length_var, isaccepted # Model related methods export simulate, volatile_simulate export distribute_mean_value_lha, mean_value_lha, distribute_prob_accept_lha -export set_param!, set_time_bound!, set_observed_var!, observe_all! +export set_param!, set_x0!, set_time_bound!, set_observed_var!, observe_all! export get_param, getproperty, get_proba_model, get_observed_var export isbounded, isaccepted, check_consistency export draw_model!, draw!, fill!, prior_pdf!, prior_pdf, insupport diff --git a/core/model.jl b/core/model.jl index 9350f3abb9dd81bf72b5e7ef4192956c02bc837b..75874f1e159f5fa62cf18624efc349d596d780c1 100644 --- a/core/model.jl +++ b/core/model.jl @@ -393,18 +393,35 @@ function observe_all!(am::Model) m._g_idx = _g_idx m._map_obs_var_idx = m.map_var_idx end -set_param!(m::ContinuousTimeModel, p::Vector{Float64}) = (m.p = p) -set_param!(m::ContinuousTimeModel, name_p::String, p_i::Float64) = (m.p[m.map_param_idx[name_p]] = p_i) -function set_param!(m::ContinuousTimeModel, l_name_p::Vector{String}, p::Vector{Float64}) +function set_param!(am::Model, new_p::Vector{Float64}) + m = get_proba_model(am) + @assert length(new_p) == m.k + m.p = new_p +end +function set_param!(am::Model, name_p::String, p_i::Float64) + m = get_proba_model(am) + m.p[m.map_param_idx[name_p]] = p_i +end +function set_param!(am::Model, l_name_p::Vector{String}, p::Vector{Float64}) + m = get_proba_model(am) + @assert length(l_name_p) == length(p) for i = eachindex(l_name_p) set_param!(m, l_name_p[i], p[i]) end end - -get_param(m::ContinuousTimeModel) = m.p -getindex(m::ContinuousTimeModel, name_p::String) = m.p[m.map_param_idx[name_p]] +function set_x0!(am::Model, new_x0::Vector{Int}) + m = get_proba_model(am) + @assert length(new_x0) == m.d + m.x0 = new_x0 +end set_time_bound!(am::Model, b::Float64) = (get_proba_model(am).time_bound = b) + +get_param(am::Model) = get_proba_model(am).p +function getindex(am::Model, name_p::String) + m = get_proba_model(am) + m.p[m.map_param_idx[name_p]] +end function getproperty(m::ContinuousTimeModel, sym::Symbol) if sym == :dobs return length(m.g) diff --git a/tests/run_unit.jl b/tests/run_unit.jl index 24cd7f7125e30057dc0bcc1282dc298f57385aa9..251f35c9ec153b2da246e9f2bf04aa47befd9941 100644 --- a/tests/run_unit.jl +++ b/tests/run_unit.jl @@ -32,6 +32,7 @@ using Test @test include("unit/set_param.jl") + @test include("unit/set_x0.jl") @test include("unit/side_effects_1.jl") @test include("unit/simulate_available_models.jl") @test include("unit/simulate_sir.jl") diff --git a/tests/unit/set_x0.jl b/tests/unit/set_x0.jl new file mode 100644 index 0000000000000000000000000000000000000000..993373c093eb241c605da96a84b701c55926ea77 --- /dev/null +++ b/tests/unit/set_x0.jl @@ -0,0 +1,21 @@ + +using MarkovProcesses + +load_model("SIR") +load_model("ER") + +test_all = true +new_x0 = [92,10,2] +set_x0!(SIR, new_x0) +test_all = test_all && SIR.x0 == new_x0 + +new_x0 = [1, 1, 4, 2] +set_x0!(ER, new_x0) +test_all = test_all && ER.x0 == new_x0 && SIR.x0 == [92, 10, 2] + +new_x0 = [10,10,2] +set_x0!(SIR, new_x0) +test_all = test_all && SIR.x0 == new_x0 && ER.x0 == [1, 1, 4, 2] + +return test_all +