From 33f82cc30313a23b4cee036ba1b86920ccaac48e Mon Sep 17 00:00:00 2001
From: Mahmoud Bentriou <mahmoud.bentriou@centralesupelec.fr>
Date: Fri, 29 Jan 2021 00:23:01 +0100
Subject: [PATCH] performance improvement of euclidean distance automaton

---
 automata/euclidean_distance_automaton.jl    | 11 +++++++----
 tests/automata/euclidean_distance_single.jl |  1 -
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/automata/euclidean_distance_automaton.jl b/automata/euclidean_distance_automaton.jl
index 5fc21ae..2d3f38e 100644
--- a/automata/euclidean_distance_automaton.jl
+++ b/automata/euclidean_distance_automaton.jl
@@ -55,11 +55,14 @@ function create_euclidean_distance_automaton(m::ContinuousTimeModel, timeline::A
         # l1 => l1
         # Defined below 
         @everywhere $(func_name(:cc, :l1, :l1, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) =
-        getfield(S, :values)[$(idx_var_t)] >= $(timeline)[convert(Int, getfield(S, :values)[$(idx_var_idx)])]
+        (tml = $(Tuple(timeline));
+         tml_idx = tml[convert(Int, getfield(S, :values)[$(idx_var_idx)])];
+         getfield(S, :values)[$(idx_var_t)] >= tml_idx)
         @everywhere $(func_name(:us, :l1, :l1, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) =
-        (setindex!(getfield(S, :values), getfield(S, :values)[$(idx_var_d)] + 
-                                        (getfield(S, :values)[$(idx_var_n)]  - $(observations)[convert(Int, getfield(S, :values)[$(idx_var_idx)])])^2, 
-                                        $(idx_var_d));
+        (y_obs = $(Tuple(observations));
+         y_obs_idx = y_obs[convert(Int, getfield(S, :values)[$(idx_var_idx)])];
+         setindex!(getfield(S, :values), getfield(S, :values)[$(idx_var_d)] + (getfield(S, :values)[$(idx_var_n)]  - y_obs_idx)^2, 
+                                         $(idx_var_d));
          setindex!(getfield(S, :values), getfield(S, :values)[$(idx_var_idx)] + 1.0, $(idx_var_idx)))
         
         @everywhere $(func_name(:cc, :l1, :l1, 2))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = true 
diff --git a/tests/automata/euclidean_distance_single.jl b/tests/automata/euclidean_distance_single.jl
index c28ad52..f1923c5 100644
--- a/tests/automata/euclidean_distance_single.jl
+++ b/tests/automata/euclidean_distance_single.jl
@@ -11,7 +11,6 @@ y_obs = vectorize(simulate(SIR), :I, tml_obs)
 sync_SIR = SIR * create_euclidean_distance_automaton(SIR, tml_obs, y_obs, :I)
 σ = simulate(sync_SIR)
 test = euclidean_distance(σ, :I, tml_obs, y_obs) == σ.state_lha_end[:d]
-@show test, euclidean_distance(σ, :I, tml_obs, y_obs), σ.state_lha_end[:d]
 
 return test
 
-- 
GitLab