Commit 987f9441 authored by Bentriou Mahmoud's avatar Bentriou Mahmoud
Browse files

Add of plots of trajectory + observe_all! + tests

parent 146e2afa
......@@ -24,13 +24,13 @@ export init_state, next_state!, read_trajectory
export load_automaton, get_index, get_value, length_var, isaccepted
# Model related methods
export simulate, set_param!, get_param, set_observed_var!
export simulate, set_param!, get_param, set_observed_var!, observe_all!
export set_time_bound!, getproperty, draw!
export isbounded, isaccepted, check_consistency
export load_model, get_module_path
# Utils
export get_module_path, cosmos_get_values
export get_module_path, cosmos_get_values, load_plots
include("common.jl")
......
......@@ -227,10 +227,10 @@ end
function simulate(m::ContinuousTimeModel, n::Int)
end
function set_observed_var!(m::Model,g::Vector{String})
function set_observed_var!(m::Model, g::Vector{String})
dobs = length(g)
_map_obs_var_idx = Dict()
_g_idx = Vector{Int}(undef, dobs)
_map_obs_var_idx = Dict{String}{Int}()
_g_idx = zeros(Int, dobs)
for i = 1:dobs
_g_idx[i] = m.map_var_idx[g[i]] # = ( (g[i] = i-th obs var)::String => idx in state space )
_map_obs_var_idx[g[i]] = i
......@@ -240,6 +240,17 @@ function set_observed_var!(m::Model,g::Vector{String})
m._map_obs_var_idx = _map_obs_var_idx
end
function observe_all!(m::Model)
g = Vector{String}(undef, m.d)
_g_idx = collect(1:m.d)
for var in keys(m.map_var_idx)
g[m.map_var_idx[var]] = var
end
m.g = g
m._g_idx = _g_idx
m._map_obs_var_idx = m.map_var_idx
end
isbounded(m::ContinuousTimeModel) = m.time_bound < Inf
function check_consistency(m::ContinuousTimeModel)
@assert m.d == length(m.map_var_idx)
......
import Plots: plot, plot!, scatter!
import Plots: palette, display, png, close
function plot(σ::AbstractTrajectory, vars::String...; filename::String = "", plot_transitions = false)
# Setup
palette_tr = palette(:default)
l_tr = unique(transitions(σ))
map_tr_color(tr) = palette_tr[findfirst(x->x==tr, l_tr)]
to_plot = vars
if length(vars) == 0
to_plot = get_obs_var(σ)
end
# Plots
p = plot(title = "Trajectory", palette = :lightrainbow)
for var in to_plot
@assert var in get_obs_var(σ)
plot!(p, times(σ), σ[var],
xlabel = "Time", ylabel = "Number of species",
label = var,
linetype=:steppost)
end
if plot_transitions
for (i, var) in enumerate(to_plot)
for tr in l_tr
idx_tr = findall(x->x==tr, transitions(σ))
label = (tr == nothing || i > 1) ? "" : tr
alpha = (tr == nothing) ? 0.0 : 0.5
scatter!(p, times(σ)[idx_tr], σ[var][idx_tr], label=label,
markershape=:cross, markeralpha=alpha,
markersize = 2,
markercolor=palette_tr[findfirst(x->x==tr, l_tr)])
end
end
end
if filename == ""
display(p)
else
png(p, filename)
close(p)
end
end
export plot
......@@ -15,3 +15,5 @@ function cosmos_get_values(name_file::String)
return dict_values
end
load_plots() = include(get_module_path() * "/core/plots.jl")
using MarkovProcesses
load_model("SIR")
σ = simulate(SIR)
load_plots()
plot(σ; plot_transitions = true)
using MarkovProcesses
load_model("ER")
load_model("SIR")
observe_all!(ER)
observe_all!(SIR)
return ER.g == ["E", "S", "ER", "P"] && SIR.g == ["S", "I", "R"]
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment