From 21f736b86df790f027fef68b3cce49659111cdf8 Mon Sep 17 00:00:00 2001
From: Mahmoud Bentriou <mahmoud.bentriou@centralesupelec.fr>
Date: Thu, 3 Dec 2020 17:18:28 +0100
Subject: [PATCH] set_x0! + tests

---
 core/MarkovProcesses.jl |  2 +-
 core/model.jl           | 29 +++++++++++++++++++++++------
 tests/run_unit.jl       |  1 +
 tests/unit/set_x0.jl    | 21 +++++++++++++++++++++
 4 files changed, 46 insertions(+), 7 deletions(-)
 create mode 100644 tests/unit/set_x0.jl

diff --git a/core/MarkovProcesses.jl b/core/MarkovProcesses.jl
index a12a622..36b234f 100644
--- a/core/MarkovProcesses.jl
+++ b/core/MarkovProcesses.jl
@@ -28,7 +28,7 @@ export get_index, get_value, length_var, isaccepted
 # Model related methods
 export simulate, volatile_simulate
 export distribute_mean_value_lha, mean_value_lha, distribute_prob_accept_lha
-export set_param!, set_time_bound!, set_observed_var!, observe_all!
+export set_param!, set_x0!, set_time_bound!, set_observed_var!, observe_all!
 export get_param, getproperty, get_proba_model, get_observed_var
 export isbounded, isaccepted, check_consistency
 export draw_model!, draw!, fill!, prior_pdf!, prior_pdf, insupport
diff --git a/core/model.jl b/core/model.jl
index 9350f3a..75874f1 100644
--- a/core/model.jl
+++ b/core/model.jl
@@ -393,18 +393,35 @@ function observe_all!(am::Model)
     m._g_idx = _g_idx
     m._map_obs_var_idx = m.map_var_idx
 end
-set_param!(m::ContinuousTimeModel, p::Vector{Float64}) = (m.p = p)
-set_param!(m::ContinuousTimeModel, name_p::String, p_i::Float64) = (m.p[m.map_param_idx[name_p]] = p_i)
-function set_param!(m::ContinuousTimeModel, l_name_p::Vector{String}, p::Vector{Float64}) 
+function set_param!(am::Model, new_p::Vector{Float64})
+    m = get_proba_model(am)
+    @assert length(new_p) == m.k
+    m.p = new_p
+end
+function set_param!(am::Model, name_p::String, p_i::Float64) 
+    m = get_proba_model(am)
+    m.p[m.map_param_idx[name_p]] = p_i
+end
+function set_param!(am::Model, l_name_p::Vector{String}, p::Vector{Float64}) 
+    m = get_proba_model(am)
+    @assert length(l_name_p) == length(p)
     for i = eachindex(l_name_p)
         set_param!(m, l_name_p[i], p[i])
     end
 end
-
-get_param(m::ContinuousTimeModel) = m.p
-getindex(m::ContinuousTimeModel, name_p::String) = m.p[m.map_param_idx[name_p]]
+function set_x0!(am::Model, new_x0::Vector{Int})
+    m = get_proba_model(am)
+    @assert length(new_x0) == m.d
+    m.x0 = new_x0
+end
 set_time_bound!(am::Model, b::Float64) = (get_proba_model(am).time_bound = b)
 
+
+get_param(am::Model) = get_proba_model(am).p
+function getindex(am::Model, name_p::String)
+    m = get_proba_model(am)
+    m.p[m.map_param_idx[name_p]]
+end
 function getproperty(m::ContinuousTimeModel, sym::Symbol)
     if sym == :dobs
         return length(m.g)
diff --git a/tests/run_unit.jl b/tests/run_unit.jl
index 24cd7f7..251f35c 100644
--- a/tests/run_unit.jl
+++ b/tests/run_unit.jl
@@ -32,6 +32,7 @@ using Test
     
     
     @test include("unit/set_param.jl")
+    @test include("unit/set_x0.jl")
     @test include("unit/side_effects_1.jl")
     @test include("unit/simulate_available_models.jl")
     @test include("unit/simulate_sir.jl")
diff --git a/tests/unit/set_x0.jl b/tests/unit/set_x0.jl
new file mode 100644
index 0000000..993373c
--- /dev/null
+++ b/tests/unit/set_x0.jl
@@ -0,0 +1,21 @@
+
+using MarkovProcesses
+
+load_model("SIR")
+load_model("ER")
+
+test_all = true
+new_x0 = [92,10,2]
+set_x0!(SIR, new_x0)
+test_all = test_all && SIR.x0 == new_x0
+
+new_x0 = [1, 1, 4, 2]
+set_x0!(ER, new_x0)
+test_all = test_all && ER.x0 == new_x0 && SIR.x0 == [92, 10, 2]
+
+new_x0 = [10,10,2]
+set_x0!(SIR, new_x0)
+test_all = test_all && SIR.x0 == new_x0 && ER.x0 == [1, 1, 4, 2]
+
+return test_all
+
-- 
GitLab