Commit 3851e846 authored by Philippe Veber's avatar Philippe Veber
Browse files

tk/Tdg09: combined diagonalization and submatrix optimizations

original and optimized versions were run and benchmarked on the
following program:

```ocaml
let () =
  let open Codepi.Orthomam in
  let loggers = [ Bistro_utils.Console_logger.create () ] in
  let db = Codepitk.Orthomam_db.make "_runs/omm" in
  let q =
    search_alignments ~pat:"*GPR87" db
    |> List.hd
    |> query ~convergent_species:(Bistro.Workflow.data species_with_echolocation)
  in
  inhouse_tdg09 q
  |> Bistro.Workflow.path
  |> Bistro_engine.Scheduler.simple_eval_exn ~loggers
  |> print_endline
```

execution time got from 24min to 18s, apparently with very little
changes on pvalue results (after the 6th decimal)
parent 8cadf5c8
...@@ -2,6 +2,9 @@ open Core_kernel ...@@ -2,6 +2,9 @@ open Core_kernel
open Phylogenetics open Phylogenetics
open Phylogenetics.Linear_algebra.Lacaml open Phylogenetics.Linear_algebra.Lacaml
let project_vector (mask : Amino_acid.t array) (v : Amino_acid.vector) =
Vector.init (Array.length mask) ~f:(fun i -> Amino_acid.Vector.get v (mask.(i) :> int))
module Evolution_model = struct module Evolution_model = struct
type param = { type param = {
stationary_distribution : Amino_acid.vector ; stationary_distribution : Amino_acid.vector ;
...@@ -31,15 +34,10 @@ module Evolution_model = struct ...@@ -31,15 +34,10 @@ module Evolution_model = struct
substitution_rate p mask.(i) mask.(j) substitution_rate p mask.(i) mask.(j)
) )
let transition_probability_submatrix mask p = let diag_expm m pi =
let m = rate_submatrix mask p in let module V = Vector in
fun t -> Matrix.(expm (scal_mul t m)) let module M = Matrix in
let sqrt_pi = V.map pi ~f:Float.sqrt in
let transition_probability_matrix p =
let module V = Amino_acid.Vector in
let module M = Amino_acid.Matrix in
let m = rate_matrix p in
let sqrt_pi = V.map p.stationary_distribution ~f:Float.sqrt in
let diag_pi = M.diagm sqrt_pi in let diag_pi = M.diagm sqrt_pi in
let diag_pi_inv = V.map sqrt_pi ~f:(fun v -> 1. /. v) |> M.diagm in let diag_pi_inv = V.map sqrt_pi ~f:(fun v -> 1. /. v) |> M.diagm in
let m' = M.(dot diag_pi @@ dot m diag_pi_inv) in let m' = M.(dot diag_pi @@ dot m diag_pi_inv) in
...@@ -47,9 +45,17 @@ module Evolution_model = struct ...@@ -47,9 +45,17 @@ module Evolution_model = struct
let transform_matrix = M.dot diag_pi_inv step_transform_matrix in let transform_matrix = M.dot diag_pi_inv step_transform_matrix in
let rev_transform_matrix = M.dot (M.transpose step_transform_matrix) diag_pi in let rev_transform_matrix = M.dot (M.transpose step_transform_matrix) diag_pi in
fun t -> fun t ->
let exp_matrix = Amino_acid.Vector.(exp (scal_mul t d_vec)) let exp_matrix = V.(exp (scal_mul t d_vec)) |> M.diagm in
|> Amino_acid.Matrix.diagm in M.(dot transform_matrix @@ dot exp_matrix rev_transform_matrix)
Amino_acid.Matrix.(dot transform_matrix @@ dot exp_matrix rev_transform_matrix)
let transition_probability_submatrix mask p =
let m = rate_submatrix mask p in
let pi = project_vector mask p.stationary_distribution in
diag_expm m pi
let transition_probability_matrix p =
let m = rate_matrix p in
diag_expm (m :> mat) (p.stationary_distribution :> vec)
let test_diagonal_matrix_exponential () = let test_diagonal_matrix_exponential () =
let stationary_distribution = Amino_acid.random_profile 0.5 in let stationary_distribution = Amino_acid.random_profile 0.5 in
...@@ -67,7 +73,7 @@ module Evolution_model = struct ...@@ -67,7 +73,7 @@ module Evolution_model = struct
let diag_exp_matrix = transition_probability_matrix p t in let diag_exp_matrix = transition_probability_matrix p t in
let m = rate_matrix p in let m = rate_matrix p in
let exp_matrix = Amino_acid.Matrix.(expm (scal_mul t m)) in let exp_matrix = Amino_acid.Matrix.(expm (scal_mul t m)) in
Amino_acid.Matrix.robust_equal ~tol:1e-10 diag_exp_matrix exp_matrix Matrix.robust_equal ~tol:1e-10 diag_exp_matrix (exp_matrix :> mat)
let%test "Matrix exponential through diagonalisation matches naive implementation" = let%test "Matrix exponential through diagonalisation matches naive implementation" =
test_diagonal_matrix_exponential () test_diagonal_matrix_exponential ()
...@@ -247,17 +253,17 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t ...@@ -247,17 +253,17 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
end end
module Model2 = struct module Model2 = struct
type param = {
scale : float ;
stationary_distribution : Amino_acid.vector ;
}
type vec_schema = { type vec_schema = {
nz : int ; (* number of non-zero AA in count table *) nz : int ; (* number of non-zero AA in count table *)
mask : Amino_acid.t array ; (* corresponding amino acids *) mask : Amino_acid.t array ; (* corresponding amino acids *)
idx_of_aa : int Amino_acid.table ; idx_of_aa : int Amino_acid.table ;
} }
type param = {
scale : float ;
stationary_distribution : Amino_acid.vector ;
}
let log_likelihood ~exchangeability_matrix ~schema ~param:{ stationary_distribution ; scale } tree site = let log_likelihood ~exchangeability_matrix ~schema ~param:{ stationary_distribution ; scale } tree site =
let p = { Evolution_model.scale ; exchangeability_matrix ; stationary_distribution } in let p = { Evolution_model.scale ; exchangeability_matrix ; stationary_distribution } in
let transition_matrix = let transition_matrix =
...@@ -270,7 +276,7 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t ...@@ -270,7 +276,7 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
in in
CTMC.pruning_with_missing_values tree CTMC.pruning_with_missing_values tree
~nstates:schema.nz ~transition_matrix ~leaf_state ~nstates:schema.nz ~transition_matrix ~leaf_state
~root_frequencies:(stationary_distribution :> Vector.t) ~root_frequencies:(project_vector schema.mask stationary_distribution)
let counts xs = let counts xs =
Amino_acid.Table.init (fun aa -> List.count xs ~f:(Amino_acid.equal aa)) Amino_acid.Table.init (fun aa -> List.count xs ~f:(Amino_acid.equal aa))
...@@ -379,10 +385,10 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t ...@@ -379,10 +385,10 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
| `Convergent -> f param.stationary_distribution1 | `Convergent -> f param.stationary_distribution1
| _ -> assert false | _ -> assert false
let log_likelihood ~exchangeability_matrix ~param tree site = let log_likelihood ~exchangeability_matrix ~schema ~param tree site =
let f cond = let f cond =
evolution_model_param exchangeability_matrix param cond evolution_model_param exchangeability_matrix param cond
|> Evolution_model.transition_probability_matrix |> Evolution_model.transition_probability_submatrix schema.Model2.mask
in in
let transition_matrix = let transition_matrix =
let f0 = f `Ancestral in (* pre-computation *) let f0 = f `Ancestral in (* pre-computation *)
...@@ -393,9 +399,12 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t ...@@ -393,9 +399,12 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
| `Ancestral -> (f0 bl :> mat) | `Ancestral -> (f0 bl :> mat)
| `Convergent -> (f1 bl :> mat) | `Convergent -> (f1 bl :> mat)
in in
let root_frequencies = (param.stationary_distribution0 :> Vector.t) in let root_frequencies = project_vector schema.mask param.stationary_distribution0 in
let leaf_state l = (aa_of_leaf_info site l :> int option) in let leaf_state l =
CTMC.pruning_with_missing_values tree ~nstates:Amino_acid.card ~transition_matrix ~leaf_state ~root_frequencies aa_of_leaf_info site l
|> Option.map ~f:(Amino_acid.Table.get schema.idx_of_aa)
in
CTMC.pruning_with_missing_values tree ~nstates:schema.nz ~transition_matrix ~leaf_state ~root_frequencies
let tuple_map (x, y) ~f = (f x, f y) let tuple_map (x, y) ~f = (f x, f y)
...@@ -442,7 +451,7 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t ...@@ -442,7 +451,7 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
in in
let f vec = let f vec =
let param = decode_vec schema vec in let param = decode_vec schema vec in
-. log_likelihood ~exchangeability_matrix ~param tree site -. log_likelihood ~exchangeability_matrix ~schema ~param tree site
in in
let sample = Model2.nelder_mead_init theta0 in let sample = Model2.nelder_mead_init theta0 in
let ll, p_star = Nelder_mead.minimize ~tol ?debug ~maxit:10_000 ~f:(Model1.clip f) ~sample () in let ll, p_star = Nelder_mead.minimize ~tol ?debug ~maxit:10_000 ~f:(Model1.clip f) ~sample () in
......
...@@ -9,7 +9,7 @@ module Evolution_model : sig ...@@ -9,7 +9,7 @@ module Evolution_model : sig
val param_of_wag : Wag.t -> float -> param val param_of_wag : Wag.t -> float -> param
val rate_matrix : param -> Amino_acid.matrix val rate_matrix : param -> Amino_acid.matrix
val stationary_distribution : param -> Amino_acid.vector val stationary_distribution : param -> Amino_acid.vector
val transition_probability_matrix : param -> float -> Amino_acid.matrix val transition_probability_matrix : param -> float -> Linear_algebra.Lacaml.mat
end end
type likelihood_ratio_test = { type likelihood_ratio_test = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment