Commit 8ac66b28 authored by Philippe Veber
Browse files

tk/Mutsel_cpg_simulator: avoid recomputing most of the rate vectors

Only recompute the rate vectors that are affected by the state change at
a given position. Complexity is still quadratic because every step has to
sample among all positions, but the constant is about 300 times better than
in the previous commit.

> df <- data.frame(n = c(10000,13000,20000,23000,30000), t = c(5.03,7.53,16.84,21.58,36.12)) ; fit <- lm(t ~ I(n ^ 2), data = df) ; summary(fit)

Call:
lm(formula = t ~ I(n^2), data = df)

Residuals:
       1        2        3        4        5
 0.05330 -0.13314  0.18311 -0.09938 -0.00389

Coefficients:
             Estimate Std. Error t value Pr(>|t|)
(Intercept) 1.083e+00  1.161e-01   9.335   0.0026 **
I(n^2)      3.893e-08  2.286e-10 170.301 4.46e-07 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 0.146 on 3 degrees of freedom
Multiple R-squared:  0.9999,	Adjusted R-squared:  0.9999
F-statistic: 2.9e+04 on 1 and 3 DF,  p-value: 4.464e-07
parent 4b3d32db
......@@ -193,41 +193,42 @@ module Make(BI: Simulator.Branch_info) = struct
Gsl.Randist.(discrete rng (discrete_preproc v))
|> NSCodon.of_int_exn
(* [memo f] returns a function that behaves like [f] but keeps every
   computed result in a private hash table, so [f] is evaluated at most
   once per distinct argument (arguments are looked up with the polymorphic
   hashing and equality of [Caml.Hashtbl]). The table is created once per
   call to [memo] and is never cleared. *)
let memo f =
  let cache = Caml.Hashtbl.create 253 in
  fun arg ->
    match Caml.Hashtbl.find_opt cache arg with
    | Some cached -> cached
    | None ->
      (* First time this argument is seen: compute, store, return. *)
      let result = f arg in
      Caml.Hashtbl.add cache arg result ;
      result
(* NOTE(review): this span is a diff rendering whose +/- markers were lost
   (see the hunk header above), so lines of the previous and of the new
   implementation are interleaved and the text is not compilable as shown.
   Judging from the duplicated bindings, the new version appears to: copy
   [seq] once per branch, build [rates] and [pos_rates] once, and then, in
   [loop], draw a waiting time, pick a position and a new letter, mutate
   [state] in place, and recompute rates only for [pos] and its in-bounds
   neighbours -- TODO confirm against the raw diff. *)
let sequence_gillespie_direct tree ~root ~param =
(* Memoized wrapper around [Evolution_model.rate_vector], keyed on the
   (parameter, symbol) pair. *)
let codon_rates = memo (fun (p, s) -> Evolution_model.rate_vector p s) in
Tree.propagate tree ~init:root ~node:Fn.const ~leaf:Fn.const ~branch:(fun seq b ->
(* NOTE(review): presumably the removed signature of [loop]; a second
   [let rec loop] taking [total_rate] appears further down. *)
let rec loop state remaining_time =
(* Work on a copy so that the incoming [seq] is not mutated by the in-place
   update of [state] below. *)
let state = Array.copy seq in
let n = Array.length state in
(* NOTE(review): first of two definitions of [rates], [pos_rates] and
   [total_rate]; this one goes through the memoized [codon_rates] and is
   presumably the removed code. *)
let rates =
Array.init n ~f:(fun i ->
codon_rates (param state i b, state.(i))
) in
let pos_rates = Array.map rates ~f:(fun r -> Owl.Stats.sum (r :> float array)) in
let total_rate = Array.reduce_exn pos_rates ~f:( +. ) in
(* [rate i] calls [Evolution_model.rate_vector] for position [i] using the
   current contents of [state]; [pos_rate i] is the sum of the entries of
   [rates.(i)]. *)
let rate i = Evolution_model.rate_vector (param state i b) state.(i) in
let rates = Array.init n ~f:rate in
let pos_rate i = Owl.Stats.sum (rates.(i) :> float array) in
let pos_rates = Array.init n ~f:pos_rate in
let rec loop total_rate remaining_time =
(* [tau] is drawn by [Owl.Stats.exponential_rvs] with lambda = [total_rate];
   the loop returns [state] as soon as [tau] exceeds the remaining time. *)
let tau = Owl.Stats.exponential_rvs ~lambda:total_rate in
if Float.(tau > remaining_time) then state
else
(* NOTE(review): two samplers for [pos] appear back to back; presumably the
   Owl one is the removed line and the Gsl one its replacement. Both are
   given [pos_rates]. *)
let pos = Owl.Stats.categorical_rvs pos_rates in
let pos = Gsl.Randist.(discrete rng (discrete_preproc pos_rates)) in
let next_letter = symbol_sample (rates.(pos) :> float array) in
(* NOTE(review): [next_state] copies the whole array for one change and has
   no closing [in] here; it matches the [loop next_state ...] call below and
   is presumably removed code. *)
let next_state =
let t = Array.copy state in
t.(pos) <- next_letter ;
t
(* Positions whose rates are refreshed after the change: [pos] itself, plus
   [pos - 1] and [pos + 1] when they are valid indices. *)
let outdated_positions =
[pos]
|> (fun l -> if pos - 1 >= 0 then (pos - 1) :: l else l)
|> (fun l -> if pos + 1 < n then (pos + 1) :: l else l)
in
(* The mutation happens before the refresh below, so [rate] reads the
   updated [state]. *)
state.(pos) <- next_letter ;
(* For each outdated position, overwrite [rates] and [pos_rates] and return
   the change in [pos_rates]; the changes are then summed. *)
let delta_total_rate =
List.map outdated_positions ~f:(fun pos ->
let old = pos_rates.(pos) in
rates.(pos) <- rate pos ;
pos_rates.(pos) <- pos_rate pos ;
pos_rates.(pos) -. old
)
|> List.fold ~init:0. ~f:( +. )
in
(* NOTE(review): two recursive calls appear; the second one carries the
   running total forward by adding [delta_total_rate] instead of summing
   [pos_rates] again. *)
loop next_state Float.(remaining_time - tau)
loop (total_rate +. delta_total_rate) Float.(remaining_time - tau)
in
(* NOTE(review): two entry calls appear; the second one computes the initial
   [total_rate] as the sum of [pos_rates] and starts [loop] with
   [BI.length b] as the remaining time. *)
loop seq (BI.length b)
let total_rate = Array.reduce_exn pos_rates ~f:( +. ) in
loop total_rate (BI.length b)
)
end
let demo seq_length ~rate_CpG ~branch_length =
let module Branch_info = struct
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment