From 987f9441d3b7bc97896bacc07184c65f75b27e13 Mon Sep 17 00:00:00 2001
From: Mahmoud Bentriou <mahmoud.bentriou@centralesupelec.fr>
Date: Thu, 26 Nov 2020 01:32:07 +0100
Subject: [PATCH] Add of plots of trajectory + observe_all! + tests

---
 core/MarkovProcesses.jl      |  4 ++--
 core/model.jl                | 17 ++++++++++---
 core/plots.jl                | 46 ++++++++++++++++++++++++++++++++++++
 core/utils.jl                |  2 ++
 tests/simulation/plot_pkg.jl |  9 +++++++
 tests/unit/observe_all.jl    | 11 +++++++++
 6 files changed, 84 insertions(+), 5 deletions(-)
 create mode 100644 core/plots.jl
 create mode 100644 tests/simulation/plot_pkg.jl
 create mode 100644 tests/unit/observe_all.jl

diff --git a/core/MarkovProcesses.jl b/core/MarkovProcesses.jl
index 11c2b2d..e547640 100644
--- a/core/MarkovProcesses.jl
+++ b/core/MarkovProcesses.jl
@@ -24,13 +24,13 @@ export init_state, next_state!, read_trajectory
 export load_automaton, get_index, get_value, length_var, isaccepted
 
 # Model related methods
-export simulate, set_param!, get_param, set_observed_var!
+export simulate, set_param!, get_param, set_observed_var!, observe_all!
 export set_time_bound!, getproperty, draw!
 export isbounded, isaccepted, check_consistency
 export load_model, get_module_path
 
 # Utils
-export get_module_path, cosmos_get_values
+export get_module_path, cosmos_get_values, load_plots
 
 include("common.jl")
 
diff --git a/core/model.jl b/core/model.jl
index 3e363ca..0a305c3 100644
--- a/core/model.jl
+++ b/core/model.jl
@@ -227,10 +227,10 @@ end
 function simulate(m::ContinuousTimeModel, n::Int)
 end
 
-function set_observed_var!(m::Model,g::Vector{String})
+function set_observed_var!(m::Model, g::Vector{String})
     dobs = length(g)
-    _map_obs_var_idx = Dict()
-    _g_idx = Vector{Int}(undef, dobs)
+    _map_obs_var_idx = Dict{String}{Int}()
+    _g_idx = zeros(Int, dobs)
     for i = 1:dobs
         _g_idx[i] = m.map_var_idx[g[i]] # = ( (g[i] = i-th obs var)::String => idx in state space )
         _map_obs_var_idx[g[i]] = i
@@ -240,6 +240,17 @@ function set_observed_var!(m::Model,g::Vector{String})
     m._map_obs_var_idx = _map_obs_var_idx
 end
 
+function observe_all!(m::Model)
+    g = Vector{String}(undef, m.d)
+    _g_idx = collect(1:m.d)
+    for var in keys(m.map_var_idx)
+        g[m.map_var_idx[var]] = var
+    end
+    m.g = g
+    m._g_idx = _g_idx
+    m._map_obs_var_idx = m.map_var_idx
+end
+
 isbounded(m::ContinuousTimeModel) = m.time_bound < Inf
 function check_consistency(m::ContinuousTimeModel) 
     @assert m.d == length(m.map_var_idx) 
diff --git a/core/plots.jl b/core/plots.jl
new file mode 100644
index 0000000..b2092ef
--- /dev/null
+++ b/core/plots.jl
@@ -0,0 +1,46 @@
+
+import Plots: plot, plot!, scatter!
+import Plots: palette, display, png, close
+
+function plot(σ::AbstractTrajectory, vars::String...; filename::String = "", plot_transitions = false)
+    # Setup 
+    palette_tr = palette(:default)
+    l_tr = unique(transitions(σ))
+    map_tr_color(tr) = palette_tr[findfirst(x->x==tr, l_tr)]
+    to_plot = vars
+    if length(vars) ==  0
+        to_plot = get_obs_var(σ)
+    end
+    
+    # Plots
+    p = plot(title = "Trajectory", palette = :lightrainbow)
+    for var in to_plot
+        @assert var in get_obs_var(σ) 
+        plot!(p, times(σ), σ[var], 
+              xlabel = "Time", ylabel = "Number of species",
+              label = var,
+              linetype=:steppost)
+    end
+    if plot_transitions
+        for (i, var) in enumerate(to_plot)
+            for tr in l_tr
+                idx_tr = findall(x->x==tr, transitions(σ))
+                label = (tr == nothing || i > 1) ? "" : tr
+                alpha = (tr == nothing) ? 0.0 : 0.5
+                scatter!(p, times(σ)[idx_tr], σ[var][idx_tr], label=label, 
+                         markershape=:cross, markeralpha=alpha, 
+                         markersize = 2,
+                         markercolor=palette_tr[findfirst(x->x==tr, l_tr)])
+            end
+        end
+    end
+    if filename == ""
+        display(p)
+    else
+        png(p, filename)
+        close(p)
+    end
+end
+
+export plot
+
diff --git a/core/utils.jl b/core/utils.jl
index 5b46b3b..c403635 100644
--- a/core/utils.jl
+++ b/core/utils.jl
@@ -15,3 +15,5 @@ function cosmos_get_values(name_file::String)
     return dict_values
 end
 
+load_plots() = include(get_module_path() * "/core/plots.jl")
+
diff --git a/tests/simulation/plot_pkg.jl b/tests/simulation/plot_pkg.jl
new file mode 100644
index 0000000..9ebbdd9
--- /dev/null
+++ b/tests/simulation/plot_pkg.jl
@@ -0,0 +1,9 @@
+
+using MarkovProcesses
+
+load_model("SIR")
+σ = simulate(SIR)
+
+load_plots()
+plot(σ; plot_transitions = true)
+
diff --git a/tests/unit/observe_all.jl b/tests/unit/observe_all.jl
new file mode 100644
index 0000000..3608bfc
--- /dev/null
+++ b/tests/unit/observe_all.jl
@@ -0,0 +1,11 @@
+
+using MarkovProcesses
+
+load_model("ER")
+load_model("SIR")
+
+observe_all!(ER)
+observe_all!(SIR)
+
+return ER.g == ["E", "S", "ER", "P"] && SIR.g == ["S", "I", "R"]
+
-- 
GitLab