From 987f9441d3b7bc97896bacc07184c65f75b27e13 Mon Sep 17 00:00:00 2001 From: Mahmoud Bentriou <mahmoud.bentriou@centralesupelec.fr> Date: Thu, 26 Nov 2020 01:32:07 +0100 Subject: [PATCH] Add of plots of trajectory + observe_all! + tests --- core/MarkovProcesses.jl | 4 ++-- core/model.jl | 17 ++++++++++--- core/plots.jl | 46 ++++++++++++++++++++++++++++++++++++ core/utils.jl | 2 ++ tests/simulation/plot_pkg.jl | 9 +++++++ tests/unit/observe_all.jl | 11 +++++++++ 6 files changed, 84 insertions(+), 5 deletions(-) create mode 100644 core/plots.jl create mode 100644 tests/simulation/plot_pkg.jl create mode 100644 tests/unit/observe_all.jl diff --git a/core/MarkovProcesses.jl b/core/MarkovProcesses.jl index 11c2b2d..e547640 100644 --- a/core/MarkovProcesses.jl +++ b/core/MarkovProcesses.jl @@ -24,13 +24,13 @@ export init_state, next_state!, read_trajectory export load_automaton, get_index, get_value, length_var, isaccepted # Model related methods -export simulate, set_param!, get_param, set_observed_var! +export simulate, set_param!, get_param, set_observed_var!, observe_all! export set_time_bound!, getproperty, draw! export isbounded, isaccepted, check_consistency export load_model, get_module_path # Utils -export get_module_path, cosmos_get_values +export get_module_path, cosmos_get_values, load_plots include("common.jl") diff --git a/core/model.jl b/core/model.jl index 3e363ca..0a305c3 100644 --- a/core/model.jl +++ b/core/model.jl @@ -227,10 +227,10 @@ end function simulate(m::ContinuousTimeModel, n::Int) end -function set_observed_var!(m::Model,g::Vector{String}) +function set_observed_var!(m::Model, g::Vector{String}) dobs = length(g) - _map_obs_var_idx = Dict() - _g_idx = Vector{Int}(undef, dobs) + _map_obs_var_idx = Dict{String}{Int}() + _g_idx = zeros(Int, dobs) for i = 1:dobs _g_idx[i] = m.map_var_idx[g[i]] # = ( (g[i] = i-th obs var)::String => idx in state space ) _map_obs_var_idx[g[i]] = i @@ -240,6 +240,17 @@ function set_observed_var!(m::Model,g::Vector{String}) m._map_obs_var_idx = _map_obs_var_idx end +function observe_all!(m::Model) + g = Vector{String}(undef, m.d) + _g_idx = collect(1:m.d) + for var in keys(m.map_var_idx) + g[m.map_var_idx[var]] = var + end + m.g = g + m._g_idx = _g_idx + m._map_obs_var_idx = m.map_var_idx +end + isbounded(m::ContinuousTimeModel) = m.time_bound < Inf function check_consistency(m::ContinuousTimeModel) @assert m.d == length(m.map_var_idx) diff --git a/core/plots.jl b/core/plots.jl new file mode 100644 index 0000000..b2092ef --- /dev/null +++ b/core/plots.jl @@ -0,0 +1,46 @@ + +import Plots: plot, plot!, scatter! +import Plots: palette, display, png, close + +function plot(σ::AbstractTrajectory, vars::String...; filename::String = "", plot_transitions = false) + # Setup + palette_tr = palette(:default) + l_tr = unique(transitions(σ)) + map_tr_color(tr) = palette_tr[findfirst(x->x==tr, l_tr)] + to_plot = vars + if length(vars) == 0 + to_plot = get_obs_var(σ) + end + + # Plots + p = plot(title = "Trajectory", palette = :lightrainbow) + for var in to_plot + @assert var in get_obs_var(σ) + plot!(p, times(σ), σ[var], + xlabel = "Time", ylabel = "Number of species", + label = var, + linetype=:steppost) + end + if plot_transitions + for (i, var) in enumerate(to_plot) + for tr in l_tr + idx_tr = findall(x->x==tr, transitions(σ)) + label = (tr == nothing || i > 1) ? "" : tr + alpha = (tr == nothing) ? 0.0 : 0.5 + scatter!(p, times(σ)[idx_tr], σ[var][idx_tr], label=label, + markershape=:cross, markeralpha=alpha, + markersize = 2, + markercolor=palette_tr[findfirst(x->x==tr, l_tr)]) + end + end + end + if filename == "" + display(p) + else + png(p, filename) + close(p) + end +end + +export plot + diff --git a/core/utils.jl b/core/utils.jl index 5b46b3b..c403635 100644 --- a/core/utils.jl +++ b/core/utils.jl @@ -15,3 +15,5 @@ function cosmos_get_values(name_file::String) return dict_values end +load_plots() = include(get_module_path() * "/core/plots.jl") + diff --git a/tests/simulation/plot_pkg.jl b/tests/simulation/plot_pkg.jl new file mode 100644 index 0000000..9ebbdd9 --- /dev/null +++ b/tests/simulation/plot_pkg.jl @@ -0,0 +1,9 @@ + +using MarkovProcesses + +load_model("SIR") +σ = simulate(SIR) + +load_plots() +plot(σ; plot_transitions = true) + diff --git a/tests/unit/observe_all.jl b/tests/unit/observe_all.jl new file mode 100644 index 0000000..3608bfc --- /dev/null +++ b/tests/unit/observe_all.jl @@ -0,0 +1,11 @@ + +using MarkovProcesses + +load_model("ER") +load_model("SIR") + +observe_all!(ER) +observe_all!(SIR) + +return ER.g == ["E", "S", "ER", "P"] && SIR.g == ["S", "I", "R"] + -- GitLab