diff --git a/core/MarkovProcesses.jl b/core/MarkovProcesses.jl index 11c2b2daa4e205514c027c528482ab8949a9d391..e547640d4d41a7a0b2c7645e08e8786cf11d9074 100644 --- a/core/MarkovProcesses.jl +++ b/core/MarkovProcesses.jl @@ -24,13 +24,13 @@ export init_state, next_state!, read_trajectory export load_automaton, get_index, get_value, length_var, isaccepted # Model related methods -export simulate, set_param!, get_param, set_observed_var! +export simulate, set_param!, get_param, set_observed_var!, observe_all! export set_time_bound!, getproperty, draw! export isbounded, isaccepted, check_consistency export load_model, get_module_path # Utils -export get_module_path, cosmos_get_values +export get_module_path, cosmos_get_values, load_plots include("common.jl") diff --git a/core/model.jl b/core/model.jl index 3e363ca8603f730ddfe91ce46ed7181166f0ce1e..0a305c3fcb878f4fb2ddae6add7d099190ed1e8a 100644 --- a/core/model.jl +++ b/core/model.jl @@ -227,10 +227,10 @@ end function simulate(m::ContinuousTimeModel, n::Int) end -function set_observed_var!(m::Model,g::Vector{String}) +function set_observed_var!(m::Model, g::Vector{String}) dobs = length(g) - _map_obs_var_idx = Dict() - _g_idx = Vector{Int}(undef, dobs) + _map_obs_var_idx = Dict{String}{Int}() + _g_idx = zeros(Int, dobs) for i = 1:dobs _g_idx[i] = m.map_var_idx[g[i]] # = ( (g[i] = i-th obs var)::String => idx in state space ) _map_obs_var_idx[g[i]] = i @@ -240,6 +240,17 @@ function set_observed_var!(m::Model,g::Vector{String}) m._map_obs_var_idx = _map_obs_var_idx end +function observe_all!(m::Model) + g = Vector{String}(undef, m.d) + _g_idx = collect(1:m.d) + for var in keys(m.map_var_idx) + g[m.map_var_idx[var]] = var + end + m.g = g + m._g_idx = _g_idx + m._map_obs_var_idx = m.map_var_idx +end + isbounded(m::ContinuousTimeModel) = m.time_bound < Inf function check_consistency(m::ContinuousTimeModel) @assert m.d == length(m.map_var_idx) diff --git a/core/plots.jl b/core/plots.jl new file mode 100644 index 0000000000000000000000000000000000000000..b2092ef7f4468ef042a475dbc13e8369ed2d3341 --- /dev/null +++ b/core/plots.jl @@ -0,0 +1,46 @@ + +import Plots: plot, plot!, scatter! +import Plots: palette, display, png, close + +function plot(σ::AbstractTrajectory, vars::String...; filename::String = "", plot_transitions = false) + # Setup + palette_tr = palette(:default) + l_tr = unique(transitions(σ)) + map_tr_color(tr) = palette_tr[findfirst(x->x==tr, l_tr)] + to_plot = vars + if length(vars) == 0 + to_plot = get_obs_var(σ) + end + + # Plots + p = plot(title = "Trajectory", palette = :lightrainbow) + for var in to_plot + @assert var in get_obs_var(σ) + plot!(p, times(σ), σ[var], + xlabel = "Time", ylabel = "Number of species", + label = var, + linetype=:steppost) + end + if plot_transitions + for (i, var) in enumerate(to_plot) + for tr in l_tr + idx_tr = findall(x->x==tr, transitions(σ)) + label = (tr == nothing || i > 1) ? "" : tr + alpha = (tr == nothing) ? 0.0 : 0.5 + scatter!(p, times(σ)[idx_tr], σ[var][idx_tr], label=label, + markershape=:cross, markeralpha=alpha, + markersize = 2, + markercolor=palette_tr[findfirst(x->x==tr, l_tr)]) + end + end + end + if filename == "" + display(p) + else + png(p, filename) + close(p) + end +end + +export plot + diff --git a/core/utils.jl b/core/utils.jl index 5b46b3bcd1d53d10f1302766a4be967c6293dcda..c403635c85567a464600399d7d9eba0875c4aa6a 100644 --- a/core/utils.jl +++ b/core/utils.jl @@ -15,3 +15,5 @@ function cosmos_get_values(name_file::String) return dict_values end +load_plots() = include(get_module_path() * "/core/plots.jl") + diff --git a/tests/simulation/plot_pkg.jl b/tests/simulation/plot_pkg.jl new file mode 100644 index 0000000000000000000000000000000000000000..9ebbdd9abc01345d005d3fe5ab290c0566e99026 --- /dev/null +++ b/tests/simulation/plot_pkg.jl @@ -0,0 +1,9 @@ + +using MarkovProcesses + +load_model("SIR") +σ = simulate(SIR) + +load_plots() +plot(σ; plot_transitions = true) + diff --git a/tests/unit/observe_all.jl b/tests/unit/observe_all.jl new file mode 100644 index 0000000000000000000000000000000000000000..3608bfc4ca812d598b81c92992b5ae15422afac4 --- /dev/null +++ b/tests/unit/observe_all.jl @@ -0,0 +1,11 @@ + +using MarkovProcesses + +load_model("ER") +load_model("SIR") + +observe_all!(ER) +observe_all!(SIR) + +return ER.g == ["E", "S", "ER", "P"] && SIR.g == ["S", "I", "R"] +