From 33f82cc30313a23b4cee036ba1b86920ccaac48e Mon Sep 17 00:00:00 2001 From: Mahmoud Bentriou <mahmoud.bentriou@centralesupelec.fr> Date: Fri, 29 Jan 2021 00:23:01 +0100 Subject: [PATCH] performance improvement of euclidean distance automaton --- automata/euclidean_distance_automaton.jl | 11 +++++++---- tests/automata/euclidean_distance_single.jl | 1 - 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/automata/euclidean_distance_automaton.jl b/automata/euclidean_distance_automaton.jl index 5fc21ae..2d3f38e 100644 --- a/automata/euclidean_distance_automaton.jl +++ b/automata/euclidean_distance_automaton.jl @@ -55,11 +55,14 @@ function create_euclidean_distance_automaton(m::ContinuousTimeModel, timeline::A # l1 => l1 # Defined below @everywhere $(func_name(:cc, :l1, :l1, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = - getfield(S, :values)[$(idx_var_t)] >= $(timeline)[convert(Int, getfield(S, :values)[$(idx_var_idx)])] + (tml = $(Tuple(timeline)); + tml_idx = tml[convert(Int, getfield(S, :values)[$(idx_var_idx)])]; + getfield(S, :values)[$(idx_var_t)] >= tml_idx) @everywhere $(func_name(:us, :l1, :l1, 1))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = - (setindex!(getfield(S, :values), getfield(S, :values)[$(idx_var_d)] + - (getfield(S, :values)[$(idx_var_n)] - $(observations)[convert(Int, getfield(S, :values)[$(idx_var_idx)])])^2, - $(idx_var_d)); + (y_obs = $(Tuple(observations)); + y_obs_idx = y_obs[convert(Int, getfield(S, :values)[$(idx_var_idx)])]; + setindex!(getfield(S, :values), getfield(S, :values)[$(idx_var_d)] + (getfield(S, :values)[$(idx_var_n)] - y_obs_idx)^2, + $(idx_var_d)); setindex!(getfield(S, :values), getfield(S, :values)[$(idx_var_idx)] + 1.0, $(idx_var_idx))) @everywhere $(func_name(:cc, :l1, :l1, 2))(S::StateLHA, x::Vector{Int}, p::Vector{Float64}) = true diff --git a/tests/automata/euclidean_distance_single.jl b/tests/automata/euclidean_distance_single.jl index c28ad52..f1923c5 100644 --- a/tests/automata/euclidean_distance_single.jl +++ b/tests/automata/euclidean_distance_single.jl @@ -11,7 +11,6 @@ y_obs = vectorize(simulate(SIR), :I, tml_obs) sync_SIR = SIR * create_euclidean_distance_automaton(SIR, tml_obs, y_obs, :I) σ = simulate(sync_SIR) test = euclidean_distance(σ, :I, tml_obs, y_obs) == σ.state_lha_end[:d] -@show test, euclidean_distance(σ, :I, tml_obs, y_obs), σ.state_lha_end[:d] return test -- GitLab