From f589101f7e36596fd84283e7ed77fd6ec5960eae Mon Sep 17 00:00:00 2001
From: Mahmoud Bentriou <mahmoud.bentriou@centralesupelec.fr>
Date: Sun, 14 Feb 2021 10:40:57 +0100
Subject: [PATCH] big fix of period automaton, updates of mean and var were
 false

---
 automata/period_automaton.jl | 123 ++++++++++++++++++-----------------
 core/plots.jl                |  10 +--
 2 files changed, 68 insertions(+), 65 deletions(-)

diff --git a/automata/period_automaton.jl b/automata/period_automaton.jl
index 89c6ee2..9fd5e9a 100644
--- a/automata/period_automaton.jl
+++ b/automata/period_automaton.jl
@@ -2,9 +2,10 @@
 #(S[:var_tp] * (S[:n]-1) + (S[:mean_tp]-S[:tp])^2) / S[:n]
 
 @everywhere f_mean_tp(mean_tp::Float64, tp::Float64, n::Float64) =
-(mean_tp * n + tp) / (n+1)
+(mean_tp * (n-1) + tp) / n
 @everywhere g_var_tp(var_tp::Float64, mean_tp::Float64, tp::Float64, n::Float64) =
-((n-1)*var_tp + (tp-mean_tp)*(tp - f_mean_tp(mean_tp, tp, n+1))) / n
+(n-2)/(n-1)*var_tp + (tp-mean_tp)^2/n
+#((n-2)*var_tp + (tp-mean_tp)*(tp - f_mean_tp(mean_tp, tp, n))) / (n-1)
 
 @everywhere mean_error(mean_tp::Float64, var_tp::Float64, ref_mean_tp::Float64, ref_var_tp::Float64) =
 abs(mean_tp - ref_mean_tp)
@@ -59,7 +60,7 @@ function create_period_automaton(m::ContinuousTimeModel, L::Float64, H::Float64,
         map_edges[loc] = Dict{Location, Vector{Edge}}()
     end
  
-    get_idx_var(var::Symbol) = map_var_automaton_idx[var] 
+    to_idx(var::Symbol) = map_var_automaton_idx[var] 
     nbr_rand = rand(1:100000)
     basename_func = "$(replace(m.name, ' '=>'_'))_$(nbr_rand)"
     basename_func = replace(basename_func, '-'=>'_')
@@ -77,21 +78,21 @@ function create_period_automaton(m::ContinuousTimeModel, L::Float64, H::Float64,
 
         # * l0 => l0prime
         @everywhere $(func_name(:cc, :l0, :l0prime, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = 
-        get_value(S, $(get_idx_var(:t))) >= $initT
+        get_value(S, $(to_idx(:t))) >= $initT
         @everywhere $(func_name(:us, :l0, :l0prime, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) =
         (setfield!(S, :loc, Symbol("l0prime"));
-         setindex!(getfield(S, :values), Inf, $(get_idx_var(:d))))
+         setindex!(getfield(S, :values), Inf, $(to_idx(:d))))
 
         # * l0 => low
         @everywhere $(func_name(:cc, :l0, :low, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = 
-        get_value(S, $(get_idx_var(:t))) >= $initT
+        get_value(S, $(to_idx(:t))) >= $initT
         @everywhere $(func_name(:us, :l0, :low, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) =
         (setfield!(S, :loc, Symbol("low"));
-         setindex!(getfield(S, :values), 0.0, $(get_idx_var(:t)));
-         setindex!(getfield(S, :values), 0.0, $(get_idx_var(:top)));
-         setindex!(getfield(S, :values), -1, $(get_idx_var(:n)));
-         setindex!(getfield(S, :values), 0.0, $(get_idx_var(:tp)));
-         setindex!(getfield(S, :values), Inf, $(get_idx_var(:d))))
+         setindex!(getfield(S, :values), 0.0, $(to_idx(:t)));
+         setindex!(getfield(S, :values), 0.0, $(to_idx(:top)));
+         setindex!(getfield(S, :values), -1, $(to_idx(:n)));
+         setindex!(getfield(S, :values), 0.0, $(to_idx(:tp)));
+         setindex!(getfield(S, :values), Inf, $(to_idx(:d))))
 
         # l0prime
         # * l0prime => l0prime
@@ -105,124 +106,124 @@ function create_period_automaton(m::ContinuousTimeModel, L::Float64, H::Float64,
         true
         @everywhere $(func_name(:us, :l0prime, :low, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) =
         (setfield!(S, :loc, Symbol("low"));
-         setindex!(getfield(S, :values), 0.0, $(get_idx_var(:t)));
-         setindex!(getfield(S, :values), 0.0, $(get_idx_var(:top)));
-         setindex!(getfield(S, :values), -1, $(get_idx_var(:n)));
-         setindex!(getfield(S, :values), 0.0, $(get_idx_var(:tp))))
+         setindex!(getfield(S, :values), 0.0, $(to_idx(:t)));
+         setindex!(getfield(S, :values), 0.0, $(to_idx(:top)));
+         setindex!(getfield(S, :values), -1, $(to_idx(:n)));
+         setindex!(getfield(S, :values), 0.0, $(to_idx(:tp))))
 
         # low 
         # * low => low
         @everywhere $(func_name(:cc, :low, :low, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = 
-        get_value(S, $(get_idx_var(:n))) < $N
+        get_value(S, $(to_idx(:n))) < $N
         @everywhere $(func_name(:us, :low, :low, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) =
         (nothing)
 
         # * low => mid 
         @everywhere $(func_name(:cc, :low, :mid, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = 
-        get_value(S, $(get_idx_var(:n))) < $N
+        get_value(S, $(to_idx(:n))) < $N
         @everywhere $(func_name(:us, :low, :mid, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) =
         (setfield!(S, :loc, Symbol("mid")))
 
         # * low => final
         @everywhere $(func_name(:cc, :low, :final, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = 
-        get_value(S, $(get_idx_var(:n))) == $N
+        get_value(S, $(to_idx(:n))) == $N
         @everywhere $(func_name(:us, :low, :final, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) =
         (setfield!(S, :loc, Symbol("final"));
-         val_d = getfield(Main, $(Meta.quot(error_func)))(get_value(S, $(get_idx_var(:mean_tp))), 
-                                                          get_value(S, $(get_idx_var(:var_tp))), 
+         val_d = getfield(Main, $(Meta.quot(error_func)))(get_value(S, $(to_idx(:mean_tp))), 
+                                                          get_value(S, $(to_idx(:var_tp))), 
                                                           $(ref_mean_tp), $(ref_var_tp));
-         setindex!(getfield(S, :values), val_d, $(get_idx_var(:d))))
+         setindex!(getfield(S, :values), val_d, $(to_idx(:d))))
 
         # mid
         # * mid => mid
         @everywhere $(func_name(:cc, :mid, :mid, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = 
-        get_value(S, $(get_idx_var(:n))) < $N
+        get_value(S, $(to_idx(:n))) < $N
         @everywhere $(func_name(:us, :mid, :mid, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) =
         (nothing)
 
         # * mid => low 
         @everywhere $(func_name(:cc, :mid, :low, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = 
-        get_value(S, $(get_idx_var(:n))) < $N &&
-        get_value(S, $(get_idx_var(:top))) == 0.0
+        get_value(S, $(to_idx(:n))) < $N &&
+        get_value(S, $(to_idx(:top))) == 0.0
         @everywhere $(func_name(:us, :mid, :low, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) =
         (setfield!(S, :loc, Symbol("low")))
 
         @everywhere $(func_name(:cc, :mid, :low, 2))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = 
-        get_value(S, $(get_idx_var(:n))) == -1.0 &&
-        get_value(S, $(get_idx_var(:top))) == 1.0
+        get_value(S, $(to_idx(:n))) == -1.0 &&
+        get_value(S, $(to_idx(:top))) == 1.0
         @everywhere $(func_name(:us, :mid, :low, 2))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) =
         (setfield!(S, :loc, Symbol("low"));
-         setindex!(getfield(S, :values), get_value(S, $(get_idx_var(:n))) + 1, $(get_idx_var(:n)));
-         setindex!(getfield(S, :values), 0.0, $(get_idx_var(:top)));
-         setindex!(getfield(S, :values), 0.0, $(get_idx_var(:tp))))
+         setindex!(getfield(S, :values), get_value(S, $(to_idx(:n))) + 1, $(to_idx(:n)));
+         setindex!(getfield(S, :values), 0.0, $(to_idx(:top)));
+         setindex!(getfield(S, :values), 0.0, $(to_idx(:tp))))
 
         @everywhere $(func_name(:cc, :mid, :low, 3))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = 
-        (0 <= get_value(S, $(get_idx_var(:n))) <= 1) &&
-        get_value(S, $(get_idx_var(:top))) == 1.0
+        (get_value(S, $(to_idx(:n))) == 0.0) &&
+        get_value(S, $(to_idx(:top))) == 1.0
         @everywhere $(func_name(:us, :mid, :low, 3))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) =
         (setfield!(S, :loc, Symbol("low"));
-         setindex!(getfield(S, :values), get_value(S, $(get_idx_var(:n))) + 1, $(get_idx_var(:n)));
-         setindex!(getfield(S, :values), 0.0, $(get_idx_var(:top)));
-         setindex!(getfield(S, :values), f_mean_tp(get_value(S, $(get_idx_var(:mean_tp))), 
-                                                   get_value(S, $(get_idx_var(:tp))),
-                                                   get_value(S, $(get_idx_var(:n)))), $(get_idx_var(:mean_tp)));
-         setindex!(getfield(S, :values), 0.0, $(get_idx_var(:tp))))
+         setindex!(getfield(S, :values), get_value(S, $(to_idx(:n))) + 1, $(to_idx(:n)));
+         setindex!(getfield(S, :values), 0.0, $(to_idx(:top)));
+         setindex!(getfield(S, :values), f_mean_tp(get_value(S, $(to_idx(:mean_tp))), 
+                                                   get_value(S, $(to_idx(:tp))),
+                                                   get_value(S, $(to_idx(:n)))), $(to_idx(:mean_tp)));
+         setindex!(getfield(S, :values), 0.0, $(to_idx(:tp))))
 
         @everywhere $(func_name(:cc, :mid, :low, 4))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = 
-        (2 <= get_value(S, $(get_idx_var(:n))) < $N) &&
-        get_value(S, $(get_idx_var(:top))) == 1.0
+        (1 <= get_value(S, $(to_idx(:n))) < $N) &&
+        get_value(S, $(to_idx(:top))) == 1.0
         @everywhere $(func_name(:us, :mid, :low, 4))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) =
         (setfield!(S, :loc, Symbol("low"));
-         setindex!(getfield(S, :values), get_value(S, $(get_idx_var(:n))) + 1, $(get_idx_var(:n)));
-         setindex!(getfield(S, :values), 0.0, $(get_idx_var(:top)));
-         setindex!(getfield(S, :values), f_mean_tp(get_value(S, $(get_idx_var(:mean_tp))), 
-                                                   get_value(S, $(get_idx_var(:tp))),
-                                                   get_value(S, $(get_idx_var(:n)))), $(get_idx_var(:mean_tp)));
-         setindex!(getfield(S, :values), g_var_tp(get_value(S, $(get_idx_var(:var_tp))), 
-                                                   get_value(S, $(get_idx_var(:mean_tp))),
-                                                   get_value(S, $(get_idx_var(:tp))),
-                                                   get_value(S, $(get_idx_var(:n)))), $(get_idx_var(:var_tp)));
-         setindex!(getfield(S, :values), 0.0, $(get_idx_var(:tp))))
+         setindex!(getfield(S, :values), get_value(S, $(to_idx(:n))) + 1, $(to_idx(:n)));
+         setindex!(getfield(S, :values), 0.0, $(to_idx(:top)));
+         setindex!(getfield(S, :values), g_var_tp(get_value(S, $(to_idx(:var_tp))), 
+                                                   get_value(S, $(to_idx(:mean_tp))),
+                                                   get_value(S, $(to_idx(:tp))),
+                                                   get_value(S, $(to_idx(:n)))), $(to_idx(:var_tp)));
+         setindex!(getfield(S, :values), f_mean_tp(get_value(S, $(to_idx(:mean_tp))), 
+                                                   get_value(S, $(to_idx(:tp))),
+                                                   get_value(S, $(to_idx(:n)))), $(to_idx(:mean_tp)));
+         setindex!(getfield(S, :values), 0.0, $(to_idx(:tp))))
 
         # * mid => high
         @everywhere $(func_name(:cc, :mid, :high, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = 
-        get_value(S, $(get_idx_var(:n))) < $N
+        get_value(S, $(to_idx(:n))) < $N
         @everywhere $(func_name(:us, :mid, :high, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) =
         (setfield!(S, :loc, Symbol("high"));
-         setindex!(getfield(S, :values), 1.0, $(get_idx_var(:top))))
+         setindex!(getfield(S, :values), 1.0, $(to_idx(:top))))
 
         # * mid => final
         @everywhere $(func_name(:cc, :mid, :final, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = 
-        get_value(S, $(get_idx_var(:n))) == $N
+        get_value(S, $(to_idx(:n))) == $N
         @everywhere $(func_name(:us, :mid, :final, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) =
         (setfield!(S, :loc, Symbol("final"));
-         val_d = getfield(Main, Meta.quot($error_func))(get_value(S, $(get_idx_var(:mean_tp))),
-                                                        get_value(S, $(get_idx_var(:var_tp))),
+         val_d = getfield(Main, Meta.quot($error_func))(get_value(S, $(to_idx(:mean_tp))),
+                                                        get_value(S, $(to_idx(:var_tp))),
                                                         $(ref_mean_tp), $(ref_var_tp));
-         setindex!(getfield(S, :values), val_d, $(get_idx_var(:d))))
+         setindex!(getfield(S, :values), val_d, $(to_idx(:d))))
 
         # high 
         # * high => high
         @everywhere $(func_name(:cc, :high, :high, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = 
-        get_value(S, $(get_idx_var(:n))) < $N
+        get_value(S, $(to_idx(:n))) < $N
         @everywhere $(func_name(:us, :high, :high, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) =
         (nothing)
 
         # * high => mid
         @everywhere $(func_name(:cc, :high, :mid, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = 
-        get_value(S, $(get_idx_var(:n))) < $N
+        get_value(S, $(to_idx(:n))) < $N
         @everywhere $(func_name(:us, :high, :mid, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) =
         (setfield!(S, :loc, Symbol("mid")))
 
         # * high => final
         @everywhere $(func_name(:cc, :high, :final, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = 
-        get_value(S, $(get_idx_var(:n))) == $N
+        get_value(S, $(to_idx(:n))) == $N
         @everywhere $(func_name(:us, :high, :final, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) =
         (setfield!(S, :loc, Symbol("final"));
-         val_d = getfield(Main, Meta.quot($error_func))(get_value(S, $(get_idx_var(:mean_tp))),
-                                                        get_value(S, $(get_idx_var(:var_tp))),
+         val_d = getfield(Main, Meta.quot($error_func))(get_value(S, $(to_idx(:mean_tp))),
+                                                        get_value(S, $(to_idx(:var_tp))),
                                                         $(ref_mean_tp), $(ref_var_tp));
-         setindex!(getfield(S, :values), val_d, $(get_idx_var(:d))))
+         setindex!(getfield(S, :values), val_d, $(to_idx(:d))))
     end
     eval(meta_elementary_functions)
 
diff --git a/core/plots.jl b/core/plots.jl
index e23ba88..5b53101 100644
--- a/core/plots.jl
+++ b/core/plots.jl
@@ -78,7 +78,8 @@ function plot!(A::LHA; label::String = "")
 end
 
 # For tests purposes
-function plot_periodic_trajectory(A::LHA, σ::SynchronizedTrajectory, sym_obs::Symbol; verbose = false, annot_size::Float64 = 6.0, filename::String = "")
+function plot_periodic_trajectory(A::LHA, σ::SynchronizedTrajectory, sym_obs::Symbol; 
+                                  verbose = false, annot_size = 6, show_tp = false, filename::String = "")
     @assert sym_obs in get_obs_var(σ) "Variable is not observed in the model"
     @assert A.name in ["Period"]
     p_sim = (σ.m).p
@@ -120,12 +121,13 @@ function plot_periodic_trajectory(A::LHA, σ::SynchronizedTrajectory, sym_obs::S
                     label = label_state, xlabel = "Time", ylabel = "Species $sym_obs")
     end
     annot_n = [(times(σ)[idx_n[i]], σ[sym_obs][idx_n[i]] - 10, text("n = $(values_n[i])", annot_size, :top)) for i = eachindex(idx_n)]
-    annot_tp = [(times(σ)[idx_n[i]], σ[sym_obs][idx_n[i]] - 10, text("tp = $(round(values_tp[i], digits = 2))", annot_size, :bottom)) for i = eachindex(idx_n)]
-    annots = vcat(annot_n, annot_tp)
+    annot_tp = [(times(σ)[idx_n[i]], σ[sym_obs][idx_n[i]] - 10, text("tp = $(round(values_tp[i], digits = 5))", annot_size, :bottom)) for i = eachindex(idx_n)]
+    annots = (show_tp) ? vcat(annot_n, annot_tp) : annot_n
     scatter!(p, times(σ)[idx_n], σ[sym_obs][idx_n], annotations = annots,
                              markershape = :utriangle, markersize = 3, label = "n")
     hline!(p, [A.constants[:L], A.constants[:H]], label = "L, H", color = :grey; linestyle = :dot)
-    
+    @show values_n
+    @show values_tp    
     if filename == ""
         display(p)
     else
-- 
GitLab