From d690d6b980a4632c5f69a4bcea3a051517a9a861 Mon Sep 17 00:00:00 2001
From: Mahmoud Bentriou <mahmoud.bentriou@centralesupelec.fr>
Date: Mon, 16 Nov 2020 00:41:39 +0100
Subject: [PATCH] Improvement of memory allocation. In fine I give the Abstract
 and Static arrays types up.

---
 core/model.jl                      | 44 ++++++++++++++++++------------
 models/SIR.jl                      | 25 +++++++++--------
 tests/unit/simulate_sir_bounded.jl |  2 +-
 3 files changed, 41 insertions(+), 30 deletions(-)

diff --git a/core/model.jl b/core/model.jl
index 7c428ab..6bea159 100644
--- a/core/model.jl
+++ b/core/model.jl
@@ -11,20 +11,20 @@ mutable struct CTMC <: ContinuousTimeModel
     map_var_idx::Dict # maps str to full state space
     _map_obs_var_idx::Dict # maps str to observed state space
     map_param_idx::Dict # maps str in parameter space
-    l_name_transitions::AbstractVector{String}
-    p::AbstractVector{Float64}
-    x0::AbstractVector{Int}
+    l_name_transitions::Vector{String}
+    p::Vector{Float64}
+    x0::Vector{Int}
     t0::Float64
-    f::Function
-    g::AbstractVector{String} # of dimension dobs
+    f!::Function
+    g::Vector{String} # of dimension dobs
     _g_idx::Vector{Int} # of dimension dobs
     is_absorbing::Function
     time_bound::Float64
 end
 
-function CTMC(d::Int, k::Int, map_var_idx::Dict, map_param_idx::Dict, l_name_transitions::AbstractVector{String}, 
-              p::AbstractVector{Float64}, x0::AbstractVector{Int}, t0::Float64, 
-              f::Function, is_absorbing::Function; g::AbstractVector{String} = keys(map_var_idx), time_bound::Float64 = Inf)
+function CTMC(d::Int, k::Int, map_var_idx::Dict, map_param_idx::Dict, l_name_transitions::Vector{String}, 
+              p::Vector{Float64}, x0::Vector{Int}, t0::Float64, 
+              f!::Function, is_absorbing::Function; g::Vector{String} = keys(map_var_idx), time_bound::Float64 = Inf)
     dobs = length(g)
     _map_obs_var_idx = Dict()
     _g_idx = Vector{Int}(undef, dobs)
@@ -32,23 +32,31 @@ function CTMC(d::Int, k::Int, map_var_idx::Dict, map_param_idx::Dict, l_name_tra
         _g_idx[i] = map_var_idx[g[i]] # = ( (g[i] = i-th obs var)::String => idx in state space )
         _map_obs_var_idx[g[i]] = i
     end
-    return CTMC(d, k, map_var_idx, _map_obs_var_idx, map_param_idx, l_name_transitions, p, x0, t0, f, g, _g_idx, is_absorbing, time_bound)
+    return CTMC(d, k, map_var_idx, _map_obs_var_idx, map_param_idx, l_name_transitions, p, x0, t0, f!, g, _g_idx, is_absorbing, time_bound)
 end
 
 function simulate(m::ContinuousTimeModel)
+    # trajectory fields
     full_values = zeros(m.d, 0)
     times = zeros(0)
-    transitions = Vector{Union{String,Nothing}}(undef, 0)
+    transitions = Vector{Union{String,Nothing}}(undef,0)
+    # values at time n
+    n = 0
     xn = m.x0
     tn = m.t0 
-    n = 0
-    while !m.is_absorbing(m.p,xn) && tn <= m.time_bound
-        xnplus1, tnplus1, tr = f(xn, tn, m.p)
+    tr = [""]
+    # at time n+1
+    xnplus1 = zeros(Int, m.d)
+    tnplus1 = zeros(Float64, 1)
+    is_absorbing = (m.is_absorbing(m.p,xn))::Bool
+    while !is_absorbing && (tn <= m.time_bound)
+        m.f!(xnplus1, tnplus1, tr, xn, tn, m.p)
         full_values = hcat(full_values, xnplus1)
-        push!(times, tnplus1)
-        push!(transitions, tr)
-        xn, tn = xnplus1, tnplus1
+        push!(times, tnplus1[1])
+        push!(transitions, tr[1])
+        xn, tn = xnplus1, tnplus1[1]
         n += 1
+        is_absorbing = m.is_absorbing(m.p,xn)::Bool
     end
     values = full_values[m._g_idx,:] 
     if is_bounded(m)
@@ -73,8 +81,8 @@ end
 is_bounded(m::Model) = m.time_bound < Inf
 function check_consistency(m::Model) end
 function simulate(m::Model, n::Int; bound::Float64 = Inf)::AbstractObservations end
-function set_param!(m::Model, p::AbstractVector{Float64})::Nothing end
-function get_param(m::Model)::AbstractVector{Float64} end
+function set_param!(m::Model, p::Vector{Float64})::Nothing end
+function get_param(m::Model)::Vector{Float64} end
 
 function load_model(name_model::String)
     include(pathof(@__MODULE__) * "/../../models/" * name_model * ".jl")
diff --git a/models/SIR.jl b/models/SIR.jl
index 451c9c2..476e3da 100644
--- a/models/SIR.jl
+++ b/models/SIR.jl
@@ -6,10 +6,11 @@ k=2
 dict_var = Dict("S" => 1, "I" => 2, "R" => 3)
 dict_p = Dict("ki" => 1, "kr" => 2)
 l_tr = ["R1","R2"]
-p = SVector(0.0012, 0.05)
-x0 = SVector(95, 5, 0)
+p = [0.0012, 0.05]
+x0 = [95, 5, 0]
 t0 = 0.0
-function f(xn::SVector{3, Int}, tn::Float64, p::SVector{2, Float64})
+function f!(xnplus1::Vector{Int}, tnplus1::Vector{Float64}, tr::Vector{String}, 
+           xn::Vector{Int}, tn::Float64, p::Vector{Float64})
     a1 = p[1] * xn[1] * xn[2]
     a2 = p[2] * xn[2]
     l_a = SVector(a1, a2)
@@ -33,16 +34,18 @@ function f(xn::SVector{3, Int}, tn::Float64, p::SVector{2, Float64})
         b_sup += l_a[i+1]
     end
  
-    nu = l_nu[:,reaction]
-    xnplus1 = SVector(xn[1]+nu[1], xn[2]+nu[2], xn[3]+nu[3])
-    tnplus1 = tn + tau
-    transition = "R$(reaction)"
+    nu = @view l_nu[:,reaction] # macro for avoiding a copy
+    xnplus1[1] = xn[1]+nu[1]
+    xnplus1[2] = xn[2]+nu[2]
+    xnplus1[3] = xn[3]+nu[3]
+    tnplus1[1] = tn + tau
+    tr[1] = "R$(reaction)"
 
-    return xnplus1, tnplus1, transition
+    #return xnplus1, tnplus1, transition
 end
-is_absorbing_sir(p::SVector{2, Float64}, xn::SVector{3, Int}) = (p[1]*xn[1]*xn[2] + p[2]*xn[2]) == 0.0
-g = SVector("I")
+is_absorbing_sir(p::Vector{Float64}, xn::Vector{Int}) = (p[1]*xn[1]*xn[2] + p[2]*xn[2]) === 0.0
+g = ["I"]
 
-SIR = CTMC(d,k,dict_var,dict_p,l_tr,p,x0,t0,f,is_absorbing_sir; g=g)
+SIR = CTMC(d,k,dict_var,dict_p,l_tr,p,x0,t0,f!,is_absorbing_sir; g=g)
 export SIR
 
diff --git a/tests/unit/simulate_sir_bounded.jl b/tests/unit/simulate_sir_bounded.jl
index c0b68e6..0a7a10f 100644
--- a/tests/unit/simulate_sir_bounded.jl
+++ b/tests/unit/simulate_sir_bounded.jl
@@ -11,7 +11,7 @@ d2 = Dict("I" => 1)
 
 bool_test = SIR.g == ["I"] && SIR._g_idx == [2] && 
             SIR.map_var_idx == d1 && 
-            SIR._map_obs_var_idx == d2
+            SIR._map_obs_var_idx == d2 && σ["times"][end] <= SIR.time_bound
 
 return bool_test
 
-- 
GitLab