# model.jl
1

2
3
import StaticArrays: SVector

4
5
6
# Root of the model hierarchy: `is_bounded`, `set_observed_var!`, `simulate(m, n; bound)`,
# `set_param!` and `get_param` below are defined on it.
abstract type Model end 
# Discrete-time branch of the hierarchy.
# NOTE(review): no concrete subtype is visible in this file.
abstract type DiscreteTimeModel <: Model end 

7
"""
    ContinuousTimeModel <: Model

Continuous-time model, simulated by `simulate(m::ContinuousTimeModel)`.

The field order below is also the argument order of the default 15-argument
constructor, so fields must not be reordered. The keyword constructor derives
`_map_obs_var_idx` and `_g_idx` from `g`.
"""
mutable struct ContinuousTimeModel <: Model
    d::Int # state space dim
    k::Int # parameter space dim
    map_var_idx::Dict{String,Int} # variable name => index in the full state vector
    _map_obs_var_idx::Dict{String,Int} # observed variable name => position among observed vars
    map_param_idx::Dict{String,Int} # parameter name => index in the parameter space
    l_name_transitions::Vector{String} # names of the transitions
    p::Vector{Float64} # parameter values, passed to `f!` and `is_absorbing`
    x0::Vector{Int} # initial state (reshaped to 1 x d by `simulate`)
    t0::Float64 # initial time
    f!::Function # step function, called as f!(mat_x, l_t, l_tr, i, xn, tn, p)
    g::Vector{String} # names of the observed variables, of dimension dobs
    _g_idx::Vector{Int} # _g_idx[i] = index of g[i] in the full state vector (dimension dobs)
    is_absorbing::Function # predicate is_absorbing(p, x); `true` stops the simulation
    time_bound::Float64 # simulation horizon; Inf means unbounded
    buffer_size::Int # number of steps simulated between two flushes into the trajectory
end

25
"""
    ContinuousTimeModel(d, k, map_var_idx, map_param_idx, l_name_transitions, p, x0, t0,
                        f!, is_absorbing; g, time_bound = Inf, buffer_size = 10)

Build a `ContinuousTimeModel`, deriving the observation fields `_g_idx` and
`_map_obs_var_idx` from the observed variable names `g`.

# Arguments
- `d`, `k`: dimensions of the state space and of the parameter space.
- `map_var_idx`: variable name => index in the full state vector.
- `map_param_idx`: parameter name => index in the parameter space.
- `l_name_transitions`: names of the transitions.
- `p`, `x0`, `t0`: parameter values, initial state and initial time.
- `f!`: step function, called by `simulate` as `f!(mat_x, l_t, l_tr, i, xn, tn, p)`.
- `is_absorbing`: predicate called as `is_absorbing(p, x)`.

# Keywords
- `g`: observed variables. Defaults to every key of `map_var_idx`, ordered by state index.
- `time_bound`: simulation horizon (`Inf` means unbounded).
- `buffer_size`: number of steps simulated between two flushes; must be at least 1.

# Throws
- `KeyError` if a name in `g` is not a key of `map_var_idx`.
- `ArgumentError` if `buffer_size < 1`.
"""
function ContinuousTimeModel(d::Int, k::Int, map_var_idx::Dict, map_param_idx::Dict,
                             l_name_transitions::Vector{String},
                             p::Vector{Float64}, x0::Vector{Int}, t0::Float64,
                             f!::Function, is_absorbing::Function;
                             g::Vector{String} = sort!(collect(String, keys(map_var_idx)),
                                                       by = name -> map_var_idx[name]),
                             time_bound::Float64 = Inf, buffer_size::Int = 10)
    # With an empty buffer, `simulate` would never advance and would loop forever.
    buffer_size >= 1 ||
        throw(ArgumentError("buffer_size must be at least 1, got $buffer_size"))

    # `_g_idx[i]` is the state index of the i-th observed variable, and
    # `_map_obs_var_idx[name]` is its position among the observed variables.
    _map_obs_var_idx = Dict{String,Int}()
    _g_idx = Vector{Int}(undef, length(g))
    for (i, name) in enumerate(g)
        _g_idx[i] = map_var_idx[name]
        _map_obs_var_idx[name] = i
    end

    # Heuristic: warn when `f!` or `is_absorbing` has two or more methods.
    if length(methods(f!)) >= 2
        @warn "You have possibly redefined a function Model.f! used in a previously instantiated model."
    end
    if length(methods(is_absorbing)) >= 2
        @warn "You have possibly redefined a function Model.is_absorbing used in a previously instantiated model."
    end

    return ContinuousTimeModel(d, k, map_var_idx, _map_obs_var_idx, map_param_idx,
                               l_name_transitions, p, x0, t0, f!, g, _g_idx, is_absorbing,
                               time_bound, buffer_size)
end

"""
    simulate(m::ContinuousTimeModel)

Simulate one trajectory of `m`, starting from `m.x0` at time `m.t0`.

# Stepping
- Steps are produced in chunks of at most `m.buffer_size` by
  `m.f!(mat_x, l_t, l_tr, i, xn, tn, m.p)`.
- That call fills row `i` of `mat_x`, `l_t[i]` and `l_tr[i]`, which are read back as the
  next state, time and transition.
- Simulation stops when `m.is_absorbing(m.p, xn)` is `true` or the time exceeds
  `m.time_bound`.

# Time-bounded models
When `is_bounded(m)`, the trajectory is closed exactly at `m.time_bound`, and that last
point carries the transition `nothing`:
- If the last jump happened after the bound, it is replaced by the previous state at
  `m.time_bound`.
- Otherwise the last state is repeated at `m.time_bound`.

# Returns
`Trajectory(m, values, times, transitions)`, where `values` views the state matrix (one
row per time point) restricted to the observed columns `m._g_idx`.
"""
function simulate(m::ContinuousTimeModel)
    # States are appended row after row to a flat vector (d entries per time point) and
    # reshaped once at the end. Re-`vcat`ing the whole matrix after every buffer flush
    # copied the trajectory again each time, which is quadratic in its length.
    flat_values = copy(m.x0)
    times = Float64[m.t0]
    transitions = Union{String,Nothing}[nothing]

    # Current state and time. `xn` is a row view so that it keeps the same type as the
    # `view(mat_x, i, :)` assigned in the loop (type stability).
    xn = view(reshape(m.x0, 1, m.d), 1, :)
    tn = m.t0
    # Buffers filled in place by `m.f!`, one row / entry per simulated step.
    mat_x = zeros(Int, m.buffer_size, m.d)
    l_t = zeros(Float64, m.buffer_size)
    l_tr = Vector{String}(undef, m.buffer_size)

    is_absorbing::Bool = m.is_absorbing(m.p, xn)
    while !is_absorbing && (tn <= m.time_bound)
        # Fill the buffers until they are full or a stopping condition is met.
        i = 0
        while i < m.buffer_size && !is_absorbing && (tn <= m.time_bound)
            i += 1
            m.f!(mat_x, l_t, l_tr, i, xn, tn, m.p)
            xn = view(mat_x, i, :)
            tn = l_t[i]
            is_absorbing = m.is_absorbing(m.p, xn)
        end
        # Flush the i filled steps. The flags are already up to date for the last state.
        for row in 1:i
            append!(flat_values, view(mat_x, row, :))
        end
        append!(times, view(l_t, 1:i))
        append!(transitions, view(l_tr, 1:i))
    end
    # One row per time point, one column per state variable.
    full_values = permutedims(reshape(flat_values, m.d, :))

    if is_bounded(m)
        if times[end] > m.time_bound
            # The last jump happened after the horizon, so at `time_bound` the process
            # is still in the previous state.
            full_values[end, :] = full_values[end-1, :]
            times[end] = m.time_bound
            transitions[end] = nothing
        else
            # Stopped before the horizon: hold the last state until `time_bound`.
            full_values = vcat(full_values, reshape(full_values[end, :], 1, m.d))
            push!(times, m.time_bound)
            push!(transitions, nothing)
        end
    end
    # Keep only the observed variables.
    values = view(full_values, :, m._g_idx)
    return Trajectory(m, values, times, transitions)
end

"""
    simulate(m::ContinuousTimeModel, n::Int)

Run `simulate(m)` `n` times in succession and collect the resulting trajectories in a
`ContinuousObservations` container of length `n`.
"""
function simulate(m::ContinuousTimeModel, n::Int)
    trajectories = ContinuousObservations(undef, n)
    for idx in eachindex(trajectories)
        trajectories[idx] = simulate(m)
    end
    return trajectories
end

100
101
102
103
104
105
106
107
108
109
110
111
112
"""
    set_observed_var!(m::Model, g::Vector{String})

Make `g` the observed variables of `m` by updating three fields:
- `m.g`: the observed names.
- `m._g_idx`: the index of each observed variable in the full state vector.
- `m._map_obs_var_idx`: each observed name => its position in `g`.

All names are resolved before `m` is touched, so a `KeyError` for a name missing from
`m.map_var_idx` leaves `m` unchanged.

Returns the new name => observed-position map.
"""
function set_observed_var!(m::Model, g::Vector{String})
    _map_obs_var_idx = Dict{String,Int}()
    _g_idx = Vector{Int}(undef, length(g))
    for (i, name) in enumerate(g)
        _g_idx[i] = m.map_var_idx[name] # state index of the i-th observed variable
        _map_obs_var_idx[name] = i
    end
    m.g = g
    m._g_idx = _g_idx
    m._map_obs_var_idx = _map_obs_var_idx
    return _map_obs_var_idx
end

# Author: Bentriou Mahmoud
113
"""
    is_bounded(m::Model)

Return `true` when the simulation horizon `m.time_bound` is strictly below `Inf`.
"""
function is_bounded(m::Model)
    return Inf > m.time_bound
end
114
"""
    check_consistency(m::Model)

Placeholder: performs no check yet and returns `nothing`.
"""
check_consistency(m::Model) = nothing
115
"""
    simulate(m::Model, n::Int; bound::Float64 = Inf)

Interface placeholder for simulating `n` trajectories of a generic model.

NOTE(review): the body yields `nothing`, which is converted to `AbstractObservations`.
Unless that type admits `nothing`, calling this fallback fails, so concrete model types
presumably provide their own method.
"""
simulate(m::Model, n::Int; bound::Float64 = Inf)::AbstractObservations = nothing
116
117
"""
    set_param!(m::Model, p::Vector{Float64})

Replace the parameter vector of `m` by `p`. The vector is stored as-is, not copied.

This default relies on a mutable `p` field, as in `ContinuousTimeModel`. Model types
without such a field must provide their own method. The previous body silently did
nothing.
"""
function set_param!(m::Model, p::Vector{Float64})::Nothing
    m.p = p
    return nothing
end

"""
    get_param(m::Model)

Return the parameter vector stored in `m.p`. This is the stored vector itself, not a copy.
"""
get_param(m::Model)::Vector{Float64} = m.p
118

119
"""
    get_module_path()

Return the directory two levels above the source file this module was loaded from
(for a package entry file `<root>/src/Pkg.jl`, that is `<root>`).

When the module was not loaded from a package, `pathof` returns `nothing` and the call
fails with a `MethodError`.
"""
function get_module_path()
    entry_file = pathof(@__MODULE__)
    return dirname(dirname(entry_file))
end
120
"""
    load_model(name_model::String)

`include` the file `models/<name_model>.jl` located under `get_module_path()`, and return
the value of its last expression.
"""
function load_model(name_model::String)
    model_file = string(get_module_path(), "/models/", name_model, ".jl")
    return include(model_file)
end