From 047eb00b6e4bba5fb3af40453faf4c7e62fee277 Mon Sep 17 00:00:00 2001
From: Mahmoud Bentriou <mahmoud.bentriou@centralesupelec.fr>
Date: Mon, 16 Nov 2020 15:55:49 +0100
Subject: [PATCH] Add function to change observed variables in model + tests

---
 core/MarkovProcesses.jl            |  2 +-
 core/model.jl                      | 13 +++++++++++++
 tests/run_unit.jl                  |  2 ++
 tests/unit/change_obs_var_sir.jl   | 17 +++++++++++++++++
 tests/unit/change_obs_var_sir_2.jl | 17 +++++++++++++++++
 5 files changed, 50 insertions(+), 1 deletion(-)
 create mode 100644 tests/unit/change_obs_var_sir.jl
 create mode 100644 tests/unit/change_obs_var_sir_2.jl

diff --git a/core/MarkovProcesses.jl b/core/MarkovProcesses.jl
index faa787c..23c1828 100644
--- a/core/MarkovProcesses.jl
+++ b/core/MarkovProcesses.jl
@@ -3,7 +3,7 @@ module MarkovProcesses
 import Base: +, -, getfield, getindex
 
 export Model, ContinuousTimeModel, DiscreteTimeModel
-export simulate, set_param!, get_param
+export simulate, set_param!, get_param, set_observed_var!
 export is_bounded
 export load_model, get_module_path
 include("model.jl")
diff --git a/core/model.jl b/core/model.jl
index 60ed622..d37ad6d 100644
--- a/core/model.jl
+++ b/core/model.jl
@@ -78,6 +78,19 @@ function simulate(m::ContinuousTimeModel, n::Int)
     return obs
 end
 
+function set_observed_var!(m::Model,g::Vector{String})
+    dobs = length(g)
+    _map_obs_var_idx = Dict()
+    _g_idx = Vector{Int}(undef, dobs)
+    for i = 1:dobs
+        _g_idx[i] = m.map_var_idx[g[i]] # = ( (g[i] = i-th obs var)::String => idx in state space )
+        _map_obs_var_idx[g[i]] = i
+    end
+    m.g = g
+    m._g_idx = _g_idx
+    m._map_obs_var_idx = _map_obs_var_idx
+end
+
 is_bounded(m::Model) = m.time_bound < Inf
 function check_consistency(m::Model) end
 function simulate(m::Model, n::Int; bound::Float64 = Inf)::AbstractObservations end
diff --git a/tests/run_unit.jl b/tests/run_unit.jl
index e85df99..443aef4 100644
--- a/tests/run_unit.jl
+++ b/tests/run_unit.jl
@@ -7,5 +7,7 @@ using Test
     @test include("unit/simulate_sir.jl")
     @test include("unit/simulate_sir_bounded.jl")
     @test include("unit/simulate_er.jl")
+    @test include("unit/change_obs_var_sir.jl")
+    @test include("unit/change_obs_var_sir_2.jl")
 end
 
diff --git a/tests/unit/change_obs_var_sir.jl b/tests/unit/change_obs_var_sir.jl
new file mode 100644
index 0000000..125b01e
--- /dev/null
+++ b/tests/unit/change_obs_var_sir.jl
@@ -0,0 +1,17 @@
+
+using MarkovProcesses
+
+load_model("SIR")
+
+σ = simulate(SIR)
+set_observed_var!(SIR, ["I", "R"])
+
+d1 = Dict("S" => 1, "I" => 2, "R" => 3)
+d2 = Dict("I" => 1, "R" => 2)
+
+bool_test = SIR.g == ["I", "R"] && SIR._g_idx == [2,3] && 
+            SIR.map_var_idx == d1 && 
+            SIR._map_obs_var_idx == d2
+
+return bool_test
+
diff --git a/tests/unit/change_obs_var_sir_2.jl b/tests/unit/change_obs_var_sir_2.jl
new file mode 100644
index 0000000..e852889
--- /dev/null
+++ b/tests/unit/change_obs_var_sir_2.jl
@@ -0,0 +1,17 @@
+
+using MarkovProcesses
+
+load_model("SIR")
+
+σ = simulate(SIR)
+set_observed_var!(SIR, ["R", "S"])
+
+d1 = Dict("S" => 1, "I" => 2, "R" => 3)
+d2 = Dict("R" => 1, "S" => 2)
+
+bool_test = SIR.g == ["R", "S"] && SIR._g_idx == [3,1] && 
+            SIR.map_var_idx == d1 && 
+            SIR._map_obs_var_idx == d2
+
+return bool_test
+
-- 
GitLab