Commit 3851e846 authored by Philippe Veber
Browse files

tk/Tdg09: combined diagonalization and submatrix optimizations

original and optimized versions were run and benchmarked on the
following program:

```ocaml
(* Benchmark driver: runs [inhouse_tdg09] on a single alignment query and
   prints the string obtained by evaluating the resulting workflow's path. *)
let () =
  let open Codepi.Orthomam in
  (* Progress is reported on the console while the workflow is evaluated. *)
  let loggers = [ Bistro_utils.Console_logger.create () ] in
  let db = Codepitk.Orthomam_db.make "_runs/omm" in
  (* Keep only the first alignment matching the pattern and build the query
     from it, using [species_with_echolocation] as the convergent species.
     NOTE(review): behaviour of [List.hd] when nothing matches depends on
     which [List] module is in scope here — confirm. *)
  let q =
    search_alignments ~pat:"*GPR87" db
    |> List.hd
    |> query ~convergent_species:(Bistro.Workflow.data species_with_echolocation)
  in
  (* Build the workflow, ask for its path, evaluate it, print the result. *)
  inhouse_tdg09 q
  |> Bistro.Workflow.path
  |> Bistro_engine.Scheduler.simple_eval_exn ~loggers
  |> print_endline
```

execution time went from 24 min to 18 s, apparently with very little
change in the p-value results (differences only after the 6th decimal)
parent 8cadf5c8
......@@ -2,6 +2,9 @@ open Core_kernel
open Phylogenetics
open Phylogenetics.Linear_algebra.Lacaml
(* [project_vector mask v] builds a vector with one component per entry of
   [mask]: component [i] is the value stored in [v] for amino acid
   [mask.(i)]. *)
let project_vector (mask : Amino_acid.t array) (v : Amino_acid.vector) =
  let component i = Amino_acid.Vector.get v (mask.(i) :> int) in
  Vector.init (Array.length mask) ~f:component
module Evolution_model = struct
type param = {
stationary_distribution : Amino_acid.vector ;
......@@ -31,15 +34,10 @@ module Evolution_model = struct
substitution_rate p mask.(i) mask.(j)
)
(* [transition_probability_submatrix mask p] builds the rate submatrix for
   [mask] once and returns a function that, for each [t], applies the generic
   matrix exponential [expm] to that submatrix scaled by [t]. *)
let transition_probability_submatrix mask p =
  let sub_rates = rate_submatrix mask p in
  fun t -> Matrix.expm (Matrix.scal_mul t sub_rates)
let transition_probability_matrix p =
let module V = Amino_acid.Vector in
let module M = Amino_acid.Matrix in
let m = rate_matrix p in
let sqrt_pi = V.map p.stationary_distribution ~f:Float.sqrt in
let diag_expm m pi =
let module V = Vector in
let module M = Matrix in
let sqrt_pi = V.map pi ~f:Float.sqrt in
let diag_pi = M.diagm sqrt_pi in
let diag_pi_inv = V.map sqrt_pi ~f:(fun v -> 1. /. v) |> M.diagm in
let m' = M.(dot diag_pi @@ dot m diag_pi_inv) in
......@@ -47,9 +45,17 @@ module Evolution_model = struct
let transform_matrix = M.dot diag_pi_inv step_transform_matrix in
let rev_transform_matrix = M.dot (M.transpose step_transform_matrix) diag_pi in
fun t ->
let exp_matrix = Amino_acid.Vector.(exp (scal_mul t d_vec))
|> Amino_acid.Matrix.diagm in
Amino_acid.Matrix.(dot transform_matrix @@ dot exp_matrix rev_transform_matrix)
let exp_matrix = V.(exp (scal_mul t d_vec)) |> M.diagm in
M.(dot transform_matrix @@ dot exp_matrix rev_transform_matrix)
(* [transition_probability_submatrix mask p] restricts both the rate matrix
   and the stationary distribution of [p] to the amino acids listed in
   [mask], then hands them to [diag_expm]. The result is whatever [diag_expm]
   returns for these two arguments. *)
let transition_probability_submatrix mask p =
  let sub_rates = rate_submatrix mask p in
  let sub_pi = p.stationary_distribution |> project_vector mask in
  diag_expm sub_rates sub_pi
(* [transition_probability_matrix p] applies [diag_expm] to the full rate
   matrix and stationary distribution of [p], both coerced to the plain
   [mat] / [vec] types expected by [diag_expm]. *)
let transition_probability_matrix p =
  let full_rates = (rate_matrix p :> mat) in
  let pi = (p.stationary_distribution :> vec) in
  diag_expm full_rates pi
let test_diagonal_matrix_exponential () =
let stationary_distribution = Amino_acid.random_profile 0.5 in
......@@ -67,7 +73,7 @@ module Evolution_model = struct
let diag_exp_matrix = transition_probability_matrix p t in
let m = rate_matrix p in
let exp_matrix = Amino_acid.Matrix.(expm (scal_mul t m)) in
Amino_acid.Matrix.robust_equal ~tol:1e-10 diag_exp_matrix exp_matrix
Matrix.robust_equal ~tol:1e-10 diag_exp_matrix (exp_matrix :> mat)
(* Inline test: passes when [test_diagonal_matrix_exponential] returns [true],
   i.e. when the exponential obtained through diagonalisation agrees with the
   one computed by the generic [expm] within the tolerance used there. *)
let%test "Matrix exponential through diagonalisation matches naive implementation" =
test_diagonal_matrix_exponential ()
......@@ -247,17 +253,17 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
end
module Model2 = struct
type param = {
scale : float ;
stationary_distribution : Amino_acid.vector ;
}
(* Description of the reduced set of amino acids a site is modelled on:
   [nz] is passed as the number of states to the pruning algorithm and [mask]
   is used to project vectors and matrices onto that set. *)
type vec_schema = {
nz : int ; (* number of non-zero AA in count table *)
mask : Amino_acid.t array ; (* corresponding amino acids *)
idx_of_aa : int Amino_acid.table ; (* integer looked up for each amino acid and used as the state of observed leaves; presumably its position in [mask] — TODO confirm *)
}
type param = {
scale : float ;
stationary_distribution : Amino_acid.vector ;
}
let log_likelihood ~exchangeability_matrix ~schema ~param:{ stationary_distribution ; scale } tree site =
let p = { Evolution_model.scale ; exchangeability_matrix ; stationary_distribution } in
let transition_matrix =
......@@ -270,7 +276,7 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
in
CTMC.pruning_with_missing_values tree
~nstates:schema.nz ~transition_matrix ~leaf_state
~root_frequencies:(stationary_distribution :> Vector.t)
~root_frequencies:(project_vector schema.mask stationary_distribution)
(* [counts xs] returns a table indexed by amino acid whose entry for [aa] is
   the number of elements of [xs] equal to [aa]. *)
let counts xs =
  let occurrences aa = List.count xs ~f:(fun x -> Amino_acid.equal aa x) in
  Amino_acid.Table.init occurrences
......@@ -379,10 +385,10 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
| `Convergent -> f param.stationary_distribution1
| _ -> assert false
let log_likelihood ~exchangeability_matrix ~param tree site =
let log_likelihood ~exchangeability_matrix ~schema ~param tree site =
let f cond =
evolution_model_param exchangeability_matrix param cond
|> Evolution_model.transition_probability_matrix
|> Evolution_model.transition_probability_submatrix schema.Model2.mask
in
let transition_matrix =
let f0 = f `Ancestral in (* pre-computation *)
......@@ -393,9 +399,12 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
| `Ancestral -> (f0 bl :> mat)
| `Convergent -> (f1 bl :> mat)
in
let root_frequencies = (param.stationary_distribution0 :> Vector.t) in
let leaf_state l = (aa_of_leaf_info site l :> int option) in
CTMC.pruning_with_missing_values tree ~nstates:Amino_acid.card ~transition_matrix ~leaf_state ~root_frequencies
let root_frequencies = project_vector schema.mask param.stationary_distribution0 in
let leaf_state l =
aa_of_leaf_info site l
|> Option.map ~f:(Amino_acid.Table.get schema.idx_of_aa)
in
CTMC.pruning_with_missing_values tree ~nstates:schema.nz ~transition_matrix ~leaf_state ~root_frequencies
(* [tuple_map (a, b) ~f] applies [f] to both components of the pair and
   returns the pair of results. *)
let tuple_map (first, second) ~f = f first, f second
......@@ -442,7 +451,7 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
in
let f vec =
let param = decode_vec schema vec in
-. log_likelihood ~exchangeability_matrix ~param tree site
-. log_likelihood ~exchangeability_matrix ~schema ~param tree site
in
let sample = Model2.nelder_mead_init theta0 in
let ll, p_star = Nelder_mead.minimize ~tol ?debug ~maxit:10_000 ~f:(Model1.clip f) ~sample () in
......
......@@ -9,7 +9,7 @@ module Evolution_model : sig
val param_of_wag : Wag.t -> float -> param
val rate_matrix : param -> Amino_acid.matrix
val stationary_distribution : param -> Amino_acid.vector
val transition_probability_matrix : param -> float -> Amino_acid.matrix
val transition_probability_matrix : param -> float -> Linear_algebra.Lacaml.mat
end
type likelihood_ratio_test = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment