Skip to content
Snippets Groups Projects
Commit 21f736b8 authored by Bentriou Mahmoud's avatar Bentriou Mahmoud
Browse files

set_x0! + tests

parent 9190a75c
No related branches found
No related tags found
No related merge requests found
...@@ -28,7 +28,7 @@ export get_index, get_value, length_var, isaccepted ...@@ -28,7 +28,7 @@ export get_index, get_value, length_var, isaccepted
# Model related methods # Model related methods
export simulate, volatile_simulate export simulate, volatile_simulate
export distribute_mean_value_lha, mean_value_lha, distribute_prob_accept_lha export distribute_mean_value_lha, mean_value_lha, distribute_prob_accept_lha
export set_param!, set_time_bound!, set_observed_var!, observe_all! export set_param!, set_x0!, set_time_bound!, set_observed_var!, observe_all!
export get_param, getproperty, get_proba_model, get_observed_var export get_param, getproperty, get_proba_model, get_observed_var
export isbounded, isaccepted, check_consistency export isbounded, isaccepted, check_consistency
export draw_model!, draw!, fill!, prior_pdf!, prior_pdf, insupport export draw_model!, draw!, fill!, prior_pdf!, prior_pdf, insupport
......
...@@ -393,18 +393,35 @@ function observe_all!(am::Model) ...@@ -393,18 +393,35 @@ function observe_all!(am::Model)
m._g_idx = _g_idx m._g_idx = _g_idx
m._map_obs_var_idx = m.map_var_idx m._map_obs_var_idx = m.map_var_idx
end end
set_param!(m::ContinuousTimeModel, p::Vector{Float64}) = (m.p = p) function set_param!(am::Model, new_p::Vector{Float64})
set_param!(m::ContinuousTimeModel, name_p::String, p_i::Float64) = (m.p[m.map_param_idx[name_p]] = p_i) m = get_proba_model(am)
function set_param!(m::ContinuousTimeModel, l_name_p::Vector{String}, p::Vector{Float64}) @assert length(new_p) == m.k
m.p = new_p
end
function set_param!(am::Model, name_p::String, p_i::Float64)
m = get_proba_model(am)
m.p[m.map_param_idx[name_p]] = p_i
end
function set_param!(am::Model, l_name_p::Vector{String}, p::Vector{Float64})
m = get_proba_model(am)
@assert length(l_name_p) == length(p)
for i = eachindex(l_name_p) for i = eachindex(l_name_p)
set_param!(m, l_name_p[i], p[i]) set_param!(m, l_name_p[i], p[i])
end end
end end
function set_x0!(am::Model, new_x0::Vector{Int})
get_param(m::ContinuousTimeModel) = m.p m = get_proba_model(am)
getindex(m::ContinuousTimeModel, name_p::String) = m.p[m.map_param_idx[name_p]] @assert length(new_x0) == m.d
m.x0 = new_x0
end
set_time_bound!(am::Model, b::Float64) = (get_proba_model(am).time_bound = b) set_time_bound!(am::Model, b::Float64) = (get_proba_model(am).time_bound = b)
get_param(am::Model) = get_proba_model(am).p
function getindex(am::Model, name_p::String)
m = get_proba_model(am)
m.p[m.map_param_idx[name_p]]
end
function getproperty(m::ContinuousTimeModel, sym::Symbol) function getproperty(m::ContinuousTimeModel, sym::Symbol)
if sym == :dobs if sym == :dobs
return length(m.g) return length(m.g)
......
...@@ -32,6 +32,7 @@ using Test ...@@ -32,6 +32,7 @@ using Test
@test include("unit/set_param.jl") @test include("unit/set_param.jl")
@test include("unit/set_x0.jl")
@test include("unit/side_effects_1.jl") @test include("unit/side_effects_1.jl")
@test include("unit/simulate_available_models.jl") @test include("unit/simulate_available_models.jl")
@test include("unit/simulate_sir.jl") @test include("unit/simulate_sir.jl")
......
using MarkovProcesses
load_model("SIR")
load_model("ER")
test_all = true
new_x0 = [92,10,2]
set_x0!(SIR, new_x0)
test_all = test_all && SIR.x0 == new_x0
new_x0 = [1, 1, 4, 2]
set_x0!(ER, new_x0)
test_all = test_all && ER.x0 == new_x0 && SIR.x0 == [92, 10, 2]
new_x0 = [10,10,2]
set_x0!(SIR, new_x0)
test_all = test_all && SIR.x0 == new_x0 && ER.x0 == [1, 1, 4, 2]
return test_all
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment