From 169c93d29e353c0f22fe67009702124b32be6daf Mon Sep 17 00:00:00 2001
From: Mahmoud Bentriou <mahmoud.bentriou@centralesupelec.fr>
Date: Sat, 5 Dec 2020 12:46:53 +0100
Subject: [PATCH] Improved performance of simulate methods: elimination of the
 unecessary call of isabsorbing() during the simulation.

Now the package is a bit more efficient than DiffEqJump.jl with the
Catalyst.jl interface which is encouraging, see bench/pkg/catalyst.jl.
---
 bench/pkg/catalyst.jl       |  32 +++++++
 core/biochemical_network.jl | 183 ++++++++++++++++++++++++++++++++++++
 core/model.jl               |  33 ++++---
 models/ER.jl                |   4 +
 models/SIR.jl               |   4 +
 5 files changed, 242 insertions(+), 14 deletions(-)
 create mode 100644 bench/pkg/catalyst.jl
 create mode 100644 core/biochemical_network.jl

diff --git a/bench/pkg/catalyst.jl b/bench/pkg/catalyst.jl
new file mode 100644
index 0000000..1f81b00
--- /dev/null
+++ b/bench/pkg/catalyst.jl
@@ -0,0 +1,32 @@
+
+using BenchmarkTools
+using MarkovProcesses
+using Catalyst
+using DiffEqJump
+
+load_model("ER")
+set_param!(ER, "k1", 0.2)
+set_param!(ER, "k2", 40.0)
+ER.buffer_size = 100
+ER.estim_min_states = 8000
+
+b_pkg = @benchmark simulate(ER)
+
+rs = @reaction_network begin
+  c1, S + E --> SE
+  c2, SE --> S + E
+  c3, SE --> P + E
+end c1 c2 c3
+p = (0.2,40.0,1.0)   # [c1,c2,c3]
+tspan = (0., 100.)
+u0 = [100., 100., 0., 0.]  # [S,E,SE,P]
+
+# solve JumpProblem
+dprob = DiscreteProblem(rs, u0, tspan, p)
+jprob = JumpProblem(rs, dprob, Direct())
+jsol = solve(jprob, SSAStepper())
+
+b_catalyst = @benchmark solve(jprob, SSAStepper())
+
+#plot(jsol,lw=2,title="Gillespie: Michaelis-Menten Enzyme Kinetics")
+
diff --git a/core/biochemical_network.jl b/core/biochemical_network.jl
new file mode 100644
index 0000000..c2592c4
--- /dev/null
+++ b/core/biochemical_network.jl
@@ -0,0 +1,183 @@
+
+using MacroTools
+
+function get_multiplicand_and_species(expr::Expr)
+    @assert expr.args[1] == :*
+    multiplicand = reduce(*, expr.args[2:(end-1)])
+    str_species = String(expr.args[end])
+    return (multiplicand, str_species)
+end
+get_multiplicand_and_species(sym::Symbol) = (1, String(sym))
+
+function get_str_propensity(propensity::Expr, dict_species::Dict, dict_params::Dict)
+    str_propensity = ""
+    for op in propensity.args[2:end]
+        str_op = String(op)
+        if haskey(dict_species, str_op)
+            str_propensity *= "xn[$(dict_species[str_op])] * "
+        elseif haskey(dict_params, str_op)
+            str_propensity *= "p[$(dict_params[str_op])] * "
+        else
+            str_propensity *= "$(str_op) * "
+        end
+    end
+    return str_propensity[1:(end-2)]
+end
+function get_str_propensity(propensity::Symbol, dict_species::Dict, dict_params::Dict)
+    str_propensity = String(propensity)
+    if haskey(dict_species, str_propensity)
+        str_propensity = "xn[$(dict_species[str_propensity])]"
+    elseif haskey(dict_params, str_propensity)
+        str_propensity = "p[$(dict_params[str_propensity])]"
+    else
+        str_propensity = "$(str_propensity)"
+    end
+    return str_propensity
+end
+
+macro biochemical_network(expr_name,expr_network)
+    transitions = String[]
+    dict_species = Dict{String,Int}()
+    dict_params = Dict{String,Int}()
+    dim_state = 0
+    dim_params = 0
+    list_expr_reactions = Any[]
+    # First we detect all of the species
+    for expr_reaction in expr_network.args
+        local isreaction = @capture(expr_reaction, TR_: (reactants_ => products_, propensity_))
+        if isreaction
+            push!(list_expr_reactions, expr_reaction)
+            push!(transitions, String(TR))
+            # Parsing reactants, products
+            for reaction_part in [reactants, products]
+                # If there's several species interacting / produced
+                if typeof(reaction_part) <: Expr && reaction_part.args[1] == :+ 
+                    for operand in reaction_part.args[2:end]
+                        mult, str_species = get_multiplicand_and_species(operand)
+                        if !haskey(dict_species, str_species)
+                            dim_state += 1
+                            dict_species[str_species] = dim_state
+                        end
+                    end
+                else
+                    mult, str_species = get_multiplicand_and_species(reaction_part)
+                    if !haskey(dict_species, str_species)
+                        dim_state += 1
+                        dict_species[str_species] = dim_state
+                    end
+                end
+            end
+        end
+        if !isreaction && !(typeof(expr_reaction) <: LineNumberNode)
+            error("Error in an expression describing a reaction")
+        end
+    end
+    list_species = [species for species in keys(dict_species)]
+    # Then we detect parameters in propensity expressions
+    # Parameters are the symbols that are not species (at this point we know all of the involved species)
+    for expr_reaction in list_expr_reactions
+        local isreaction = @capture(expr_reaction, TR_: (reactants_ => products_, propensity_))
+        if typeof(propensity) <: Expr 
+            @assert propensity.args[1] == :* "Only product of species/params/constants are allowed in propensity"
+            for operand in propensity.args[2:end]
+                if typeof(operand) <: Symbol
+                    str_op = String(operand)
+                    # If it's not a species, it's a parameter
+                    if !(str_op in list_species) && !haskey(dict_params, str_op)
+                        dim_params += 1
+                        dict_params[str_op] = dim_params
+                    end
+                end
+            end
+        elseif typeof(propensity) <: Symbol
+            str_op = String(propensity)
+            if !(str_op in list_species) && !haskey(dict_params, str_op)
+                dim_params += 1
+                dict_params[str_op] = dim_params
+            end
+        end
+        if !isreaction && !(typeof(expr_reaction) <: LineNumberNode)
+            error("Error in an expression describing a reaction")
+        end
+    end
+    # Let's write some lines that creates the function f! (step of a simulation) for this biochemical network
+    nbr_rand = rand(1:1000)
+    nbr_reactions = length(list_expr_reactions)
+    basename_func = "$(replace(expr_name, ' '=>'_'))_$(nbr_rand)"
+    expr_model_f! = "function $(basename_func)_f!(xnplus1::Vector{Int}, l_t::Vector{Float64}, l_tr::Vector{Union{Nothing,String}}, xn::Vector{Int}, tn::Float64, p::Vector{Float64})\n\t"
+    # Computation of nu and propensity functions in f!
+    str_l_a = "l_a = ("
+    str_test_isabsorbing = "("
+    l_nu = [zeros(Int, dim_state) for i = 1:nbr_reactions]
+    for (i, expr_reaction) in enumerate(list_expr_reactions)
+        local isreaction = @capture(expr_reaction, TR_: (reactants_ => products_, propensity_))
+        # Writing of propensities function
+        str_propensity = get_str_propensity(propensity, dict_species, dict_params)
+        expr_model_f! *= "a$(i) = " * str_propensity * "\n\t"
+        # Anticipating the write of the function isabsorbing
+        str_test_isabsorbing *= str_propensity * "+"
+        # Update the nu of the i-th reaction 
+        nu = l_nu[i]
+        if typeof(reactants) <: Expr && reactants.args[1] == :+ 
+            for operand in reactants.args[2:end]
+                mult, str_species = get_multiplicand_and_species(operand)
+                nu[dict_species[str_species]] -= mult
+            end
+        else
+            mult, str_species = get_multiplicand_and_species(reactants)
+            nu[dict_species[str_species]] -= mult
+        end
+        if typeof(products) <: Expr && products.args[1] == :+ 
+            for operand in products.args[2:end]
+                mult, str_species = get_multiplicand_and_species(operand)
+                nu[dict_species[str_species]] += mult
+            end
+        else
+            mult, str_species = get_multiplicand_and_species(products)
+            nu[dict_species[str_species]] += mult
+        end
+        expr_model_f! *= "nu_$i = $(Tuple(nu))\n\t"
+        # Anticipating the line l_a = (..)
+        str_l_a *= "a$(i), "
+    end
+    str_test_isabsorbing = str_test_isabsorbing[1:(end-2)] * ")"
+    str_l_a = str_l_a[1:(end-2)] * ")\n\t"
+    expr_model_f! *= str_l_a
+    expr_model_f! *= "asum = sum(l_a)\n\t"
+    expr_model_f! *= "if asum == 0.0\n\t\t"
+    expr_model_f! *= "copyto!(xnplus1, xn)\n\t\t"
+    expr_model_f! *= "return nothing\n\t"
+    expr_model_f! *= "end\n\t"
+    # Computation of array of transitions
+    expr_model_f! *= "l_nu = (" * reduce(*, ["nu_$i, " for i = 1:nbr_reactions])[1:(end-2)] * ")\n\t"
+    expr_model_f! *= "l_str_R = $(Tuple(transitions))\n\t"
+    # Simulation of the reaction
+    expr_model_f! *= "u1 = rand()\n\t"
+    expr_model_f! *= "u2 = rand()\n\t"
+    expr_model_f! *= "tau = - log(u1) / asum\n\t"
+    expr_model_f! *= "b_inf = 0.0\n\t" 
+    expr_model_f! *= "b_sup = a1\n\t" 
+    expr_model_f! *= "reaction = 0\n\n\t" 
+    expr_model_f! *= "for i = 1:$(nbr_reactions)\n\t\t"
+    expr_model_f! *= "if b_inf < asum*u2 < b_sup\n\t\t\t"
+    expr_model_f! *= "reaction = i\n\t\t\t"
+    expr_model_f! *= "break\n\t\t"
+    expr_model_f! *= "end\n\t\t"
+    expr_model_f! *= "b_inf += l_a[i]\n\t\t"
+    expr_model_f! *= "b_sup += l_a[i+1]\n\t"
+    expr_model_f! *= "end\n\t"
+    expr_model_f! *= "nu = l_nu[reaction]\n\t"
+    expr_model_f! *= "for i = 1:$(dim_state)\n\t\t"
+    expr_model_f! *= "xnplus1[i] = xn[i]+nu[i]\n\t"
+    expr_model_f! *= "end\n\t"
+    expr_model_f! *= "l_t[1] = tn + tau\n\t"
+    expr_model_f! *= "l_tr[1] = l_str_R[reaction]\n"
+    expr_model_f! *= "end\n"
+    
+    expr_model_isabsorbing = "isabsorbing_$(basename_func)(p::Vector{Float64},xn::Vector{Int}) = $(str_test_isabsorbing) === 0.0"
+    model_f! = eval(Meta.parse(expr_model_f!))
+    model_isabsorbing = eval(Meta.parse(expr_model_isabsorbing))
+    return :(ContinuousTimeModel($dim_state, $dim_params, $dict_species, $dict_params, $transitions, 
+                                 $(zeros(dim_params)), $(zeros(Int, dim_state)), 0.0, $model_f!, $model_isabsorbing; g = $list_species))
+end
+
diff --git a/core/model.jl b/core/model.jl
index 894a7a3..26136ba 100644
--- a/core/model.jl
+++ b/core/model.jl
@@ -67,16 +67,13 @@ function simulate(m::ContinuousTimeModel; p::Union{Nothing,AbstractVector{Float6
     for i = 2:m.estim_min_states
         m.f!(vec_x, l_t, l_tr, xn, tn, p_sim)
         tn = l_t[1]
-        if tn > m.time_bound
+        if tn > m.time_bound || vec_x == xn
+            isabsorbing = (vec_x == xn)
             break
         end
         n += 1
         copyto!(xn, vec_x)
         _update_values!(full_values, times, transitions, xn, tn, l_tr[1], i)
-        isabsorbing = m.isabsorbing(p_sim,xn)
-        if isabsorbing 
-            break
-        end
     end
     # If simulation ended before the estimation of states
     if n < m.estim_min_states
@@ -97,17 +94,18 @@ function simulate(m::ContinuousTimeModel; p::Union{Nothing,AbstractVector{Float6
             i += 1
             m.f!(vec_x, l_t, l_tr, xn, tn, p_sim)
             tn = l_t[1]
-            if tn > m.time_bound
+            if tn > m.time_bound 
+                i -= 1
+                break
+            end
+            if vec_x == xn
+                isabsorbing = true
                 i -= 1
                 break
             end
             copyto!(xn, vec_x)
             _update_values!(full_values, times, transitions, 
                             xn, tn, l_tr[1], m.estim_min_states+size_tmp+i)
-            isabsorbing = m.isabsorbing(p_sim,xn)
-            if isabsorbing 
-                break
-            end
         end
         # If simulation ended before the end of buffer
         if i < m.buffer_size
@@ -172,7 +170,8 @@ function simulate(product::SynchronizedModel; p::Union{Nothing,AbstractVector{Fl
     for i = 2:m.estim_min_states
         m.f!(vec_x, l_t, l_tr, xn, tn, p_sim)
         tn = l_t[1]
-        if tn > m.time_bound
+        if tn > m.time_bound || vec_x == xn
+            isabsorbing = (vec_x == xn)
             break
         end
         n += 1
@@ -181,7 +180,6 @@ function simulate(product::SynchronizedModel; p::Union{Nothing,AbstractVector{Fl
         next_state!(Snplus1, A, xn, tn, tr_n, Sn; verbose = verbose)
         _update_values!(full_values, times, transitions, xn, tn, tr_n, i)
         copyto!(Sn, Snplus1)
-        isabsorbing = m.isabsorbing(p_sim,xn)
         isacceptedLHA = isaccepted(Snplus1)
         if isabsorbing || isacceptedLHA 
             break
@@ -214,13 +212,17 @@ function simulate(product::SynchronizedModel; p::Union{Nothing,AbstractVector{Fl
                 i -= 1
                 break
             end
+            if vec_x == xn
+                isabsorbing = true
+                i -= 1
+                break
+            end
             copyto!(xn, vec_x)
             tr_n = l_tr[1]
             next_state!(Snplus1, A, xn, tn, tr_n, Sn; verbose = verbose)
             _update_values!(full_values, times, transitions, 
                             xn, tn, tr_n, m.estim_min_states+size_tmp+i)
             copyto!(Sn, Snplus1)
-            isabsorbing = m.isabsorbing(p_sim,xn)
             isacceptedLHA = isaccepted(Snplus1)
             if isabsorbing || isacceptedLHA
                 break
@@ -288,11 +290,14 @@ function volatile_simulate(product::SynchronizedModel;
             i -= 1
             break
         end
+        if vec_x == xn
+            isabsorbing = true
+            break
+        end
         copyto!(xn, vec_x)
         tr_n = l_tr[1]
         next_state!(Snplus1, A, xn, tn, tr_n, Sn; verbose = verbose)
         copyto!(Sn, Snplus1)
-        isabsorbing = m.isabsorbing(p_sim,xn)
         isacceptedLHA = isaccepted(Snplus1)
         n += 1
     end
diff --git a/models/ER.jl b/models/ER.jl
index f8a6fe7..11e703b 100644
--- a/models/ER.jl
+++ b/models/ER.jl
@@ -16,6 +16,10 @@ function ER_f!(xnplus1::Vector{Int}, l_t::Vector{Float64}, l_tr::Vector{Union{No
     a3 = p[3] * xn[3]
     l_a = (a1, a2, a3)
     asum = sum(l_a)
+    if asum == 0.0
+        copyto!(xnplus1, xn)
+        return nothing
+    end
     nu_1 = (-1, -1, 1, 0)
     nu_2 = (1, 1, -1, 0)
     nu_3 = (1, 0, -1, 1) 
diff --git a/models/SIR.jl b/models/SIR.jl
index 82ea26f..93fbbb9 100644
--- a/models/SIR.jl
+++ b/models/SIR.jl
@@ -15,6 +15,10 @@ function SIR_f!(xnplus1::Vector{Int}, l_t::Vector{Float64}, l_tr::Vector{Union{N
     a2 = p[2] * xn[2]
     l_a = (a1, a2)
     asum = sum(l_a)
+    if asum == 0.0
+        copyto!(xnplus1, xn)
+        return nothing
+    end
     # column-major order
     nu_1 = (-1, 1, 0)
     nu_2 = (0, -1, 1)
-- 
GitLab