Commit f02d2187 authored by Philippe Veber
Browse files

Merge branch 'tdg09-submatrix-optimization'

parents 7fa09ac2 3851e846
......@@ -2,30 +2,42 @@ open Core_kernel
open Phylogenetics
open Phylogenetics.Linear_algebra.Lacaml
(* [project_vector mask v] builds a vector with one component per entry of
   [mask]; the component built for index [i] is the value [v] holds for the
   amino acid [mask.(i)]. *)
let project_vector (mask : Amino_acid.t array) (v : Amino_acid.vector) =
  let component_at rank = Amino_acid.Vector.get v (mask.(rank) :> int) in
  Vector.init (Array.length mask) ~f:component_at
module Evolution_model = struct
(* Parameters of the amino-acid substitution model; the three fields are
   multiplied together entry-wise by [substitution_rate]. *)
type param = {
stationary_distribution : Amino_acid.vector ; (* per-amino-acid weight; the weight of the column amino acid enters each rate *)
exchangeability_matrix : Amino_acid.matrix ; (* pairwise factor, read at (i, j) for the rate of entry (i, j) *)
scale : float ; (* global multiplicative factor applied to every rate *)
}
(* Builds model parameters from a WAG description: the stationary distribution
   and exchangeabilities are read from [wag.freqs] and [wag.rate_matrix], while
   [scale] is supplied by the caller. *)
let param_of_wag (wag : Wag.t) scale =
  let stationary_distribution = wag.freqs in
  let exchangeability_matrix = wag.rate_matrix in
  { stationary_distribution ; exchangeability_matrix ; scale }
(* Accessor for the stationary distribution stored in a model parameter. *)
let stationary_distribution { stationary_distribution ; _ } = stationary_distribution
let rate_matrix p =
Rate_matrix.Amino_acid.make (fun i j ->
(* Entry (i, j) of the rate matrix defined by [p]: the scale, times the
   (i, j) exchangeability, times the stationary weight of [j]. *)
let substitution_rate p i j =
  let exchangeability = p.exchangeability_matrix.Amino_acid.%{i, j} in
  let weight_of_j = p.stationary_distribution.Amino_acid.%(j) in
  (* Same left-to-right association as [scale *. exch *. weight]. *)
  p.scale *. exchangeability *. weight_of_j
(* Rate matrix over the whole amino-acid alphabet, filled entry by entry with
   [substitution_rate p]. *)
let rate_matrix p =
  Rate_matrix.Amino_acid.make (fun aa_i aa_j -> substitution_rate p aa_i aa_j)
(* Rate matrix restricted to the amino acids listed in [mask]: it has
   [Array.length mask] states, and entry (row, col) is the substitution rate
   between [mask.(row)] and [mask.(col)]. *)
let rate_submatrix mask p =
  let entry row col = substitution_rate p mask.(row) mask.(col) in
  Rate_matrix.make (Array.length mask) ~f:entry
let transition_probability_matrix p =
let module V = Amino_acid.Vector in
let module M = Amino_acid.Matrix in
let m = rate_matrix p in
let sqrt_pi = V.map p.stationary_distribution ~f:Float.sqrt in
let diag_expm m pi =
let module V = Vector in
let module M = Matrix in
let sqrt_pi = V.map pi ~f:Float.sqrt in
let diag_pi = M.diagm sqrt_pi in
let diag_pi_inv = V.map sqrt_pi ~f:(fun v -> 1. /. v) |> M.diagm in
let m' = M.(dot diag_pi @@ dot m diag_pi_inv) in
......@@ -33,9 +45,17 @@ module Evolution_model = struct
let transform_matrix = M.dot diag_pi_inv step_transform_matrix in
let rev_transform_matrix = M.dot (M.transpose step_transform_matrix) diag_pi in
fun t ->
let exp_matrix = Amino_acid.Vector.(exp (scal_mul t d_vec))
|> Amino_acid.Matrix.diagm in
Amino_acid.Matrix.(dot transform_matrix @@ dot exp_matrix rev_transform_matrix)
let exp_matrix = V.(exp (scal_mul t d_vec)) |> M.diagm in
M.(dot transform_matrix @@ dot exp_matrix rev_transform_matrix)
(* Same computation as [transition_probability_matrix], but on the states
   selected by [mask]: both the rate matrix and the stationary distribution
   are reduced to [mask] before being handed to [diag_expm]. *)
let transition_probability_submatrix mask p =
  diag_expm
    (rate_submatrix mask p)
    (project_vector mask p.stationary_distribution)
(* Applies [diag_expm] to the rate matrix of [p] over the whole amino-acid
   alphabet. The alphabet-indexed matrix and vector are coerced to the plain
   [mat] / [vec] types that [diag_expm] works on; per the interface, the
   result maps a float to a [Linear_algebra.Lacaml.mat]. *)
let transition_probability_matrix p =
let m = rate_matrix p in
diag_expm (m :> mat) (p.stationary_distribution :> vec)
let test_diagonal_matrix_exponential () =
let stationary_distribution = Amino_acid.random_profile 0.5 in
......@@ -53,18 +73,17 @@ module Evolution_model = struct
let diag_exp_matrix = transition_probability_matrix p t in
let m = rate_matrix p in
let exp_matrix = Amino_acid.Matrix.(expm (scal_mul t m)) in
Amino_acid.Matrix.robust_equal ~tol:1e-10 diag_exp_matrix exp_matrix
Matrix.robust_equal ~tol:1e-10 diag_exp_matrix (exp_matrix :> mat)
(* Inline test: runs [test_diagonal_matrix_exponential], which compares the
   output of [transition_probability_matrix] with [Amino_acid.Matrix.expm]
   applied to the scaled rate matrix, using a tolerance of 1e-10. *)
let%test "Matrix exponential through diagonalisation matches naive implementation" =
test_diagonal_matrix_exponential ()
end
(* Converts the amino-acid vector [p] to a table and returns the result of
   [Amino_acid.Table.choose] on it. *)
let choose_aa p =
  let weights = Amino_acid.Table.of_vector p in
  Amino_acid.Table.choose weights
module CTMC = Phylo_ctmc.Make(Amino_acid)
module CTMC = Phylo_ctmc
(* Tolerance passed as [~tol] to [Nelder_mead.minimize] by the fitting
   functions of this file. *)
let tol = 0.001
......@@ -200,8 +219,8 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
let f = Evolution_model.transition_probability_matrix p in
fun b -> (f (Branch_info.length b) :> mat)
in
let leaf_state = aa_of_leaf_info site in
CTMC.pruning_with_missing_values tree ~transition_matrix ~leaf_state ~root_frequencies:pi
let leaf_state l = (aa_of_leaf_info site l :> int option) in
CTMC.pruning_with_missing_values tree ~nstates:Amino_acid.card ~transition_matrix ~leaf_state ~root_frequencies:pi
let clip f param =
if Float.(param.(0) > 3.) then Float.infinity
......@@ -234,54 +253,65 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
end
module Model2 = struct
(* Describes which amino acids are kept as states, and how to go from an
   amino acid to its position among the kept ones. *)
type vec_schema = {
nz : int ; (* number of amino acids kept: those with a non-zero count in the sparse schema, all of them in the dense schema *)
mask : Amino_acid.t array ; (* corresponding amino acids *)
idx_of_aa : int Amino_acid.table ; (* position of each amino acid in [mask]; the sparse schema leaves -1 for amino acids that are not in [mask] *)
}
(* Parameters of Model2. [log_likelihood] combines them with an
   exchangeability matrix to build an [Evolution_model.param]. *)
type param = {
scale : float ; (* becomes [Evolution_model.scale] *)
stationary_distribution : Amino_acid.vector ; (* becomes the stationary distribution of the evolution model *)
}
let log_likelihood ~exchangeability_matrix ~param:{ stationary_distribution ; scale } tree site =
let log_likelihood ~exchangeability_matrix ~schema ~param:{ stationary_distribution ; scale } tree site =
let p = { Evolution_model.scale ; exchangeability_matrix ; stationary_distribution } in
let transition_matrix =
let f = Evolution_model.transition_probability_matrix p in
let f = Evolution_model.transition_probability_submatrix schema.mask p in
fun b -> (f (Branch_info.length b) :> mat)
in
let leaf_state = aa_of_leaf_info site in
CTMC.pruning_with_missing_values tree ~transition_matrix ~leaf_state ~root_frequencies:(stationary_distribution :> Vector.t)
let leaf_state l =
aa_of_leaf_info site l
|> Option.map ~f:(Amino_acid.Table.get schema.idx_of_aa)
in
CTMC.pruning_with_missing_values tree
~nstates:schema.nz ~transition_matrix ~leaf_state
~root_frequencies:(project_vector schema.mask stationary_distribution)
(* Table giving, for every amino acid, how many elements of the list [xs]
   are equal to it. *)
let counts xs =
  let occurrences aa = List.count xs ~f:(fun x -> Amino_acid.equal aa x) in
  Amino_acid.Table.init occurrences
type vec_schema = {
nz : int ; (* number of non-zero AA in count table *)
idx : int array ; (* indices of non-zero AA *)
}
let sparse_param_schema counts =
let k = (counts : int Amino_acid.table :> _ array) in
let idx, nz = Array.foldi k ~init:([], 0) ~f:(fun i ((assoc, nz) as acc) k_i ->
if k_i = 0 then acc else i :: assoc, nz + 1
if k_i = 0 then acc else Amino_acid.of_int_exn i :: assoc, nz + 1
)
in
let idx = Array.of_list idx in
{ nz ; idx }
let mask = Array.of_list idx in
let idx_of_aa =
let r = Amino_acid.Table.init (fun _ -> -1) in
Array.iteri mask ~f:(fun i aa -> Amino_acid.Table.set r aa i) ;
r
in
{ nz ; mask ; idx_of_aa }
let dense_param_schema counts =
let nz = Array.length (counts : int Amino_acid.table :> _ array) in
let idx = Array.init nz ~f:Fn.id in
{ nz ; idx }
let dense_param_schema () =
let nz = Amino_acid.card in
let mask = Array.init nz ~f:Amino_acid.of_int_exn in
let idx_of_aa = Amino_acid.Table.init Amino_acid.to_int in
{ nz ; mask ; idx_of_aa }
let profile_guess schema counts =
let counts = (counts : int Amino_acid.table :> _ array) in
let total_counts = Array.fold counts ~init:0. ~f:(fun acc x -> 1. +. acc +. float x) in
Array.map schema.idx ~f:(fun idx -> Float.log (float (1 + counts.(idx)) /. total_counts))
let profile_guess schema counts = (* FIXME: are these pseudo-counts really useful? *)
let total_counts = Amino_acid.Table.fold counts ~init:0. ~f:(fun acc x -> 1. +. acc +. float x) in
Array.map schema.mask ~f:(fun idx -> Float.log (float (1 + Amino_acid.Table.get counts idx) /. total_counts))
(* Starting point for the optimizer: a leading 0. followed by the values
   computed by [profile_guess].
   NOTE(review): the leading slot is presumably the scale parameter; confirm
   against [decode_vec]. *)
let initial_param schema counts =
  let leading_slot = [| 0. |] in
  let profile_slots = profile_guess schema counts in
  Array.concat [ leading_slot ; profile_slots ]
let extract_frequencies ~offset schema param =
let r = Array.create ~len:Amino_acid.card 0. in
Array.iteri schema.idx ~f:(fun sparse_idx full_idx ->
r.(full_idx) <- Float.exp param.(sparse_idx + offset)
Array.iteri schema.mask ~f:(fun sparse_idx aa ->
r.((aa :> int)) <- Float.exp param.(sparse_idx + offset)
) ;
let s = Owl.Stats.sum r in
Amino_acid.Vector.init (fun aa -> r.((aa :> int)) /. s)
......@@ -289,7 +319,7 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
let param_schema ?(mode = `sparse) counts =
match mode with
| `sparse -> sparse_param_schema counts
| `dense -> dense_param_schema counts
| `dense -> dense_param_schema ()
let nelder_mead_init theta0 =
let c = ref (-1) in
......@@ -313,11 +343,12 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
|> counts
in
let schema = param_schema ?mode counts in
let module A = Alphabet.Make(struct let card = schema.nz end) in
let theta0 = initial_param schema counts in
let sample = nelder_mead_init theta0 in
let f p =
let param = decode_vec schema p in
-. log_likelihood ~exchangeability_matrix ~param tree site
-. log_likelihood ~exchangeability_matrix ~schema ~param tree site
in
let ll, p_star = Nelder_mead.minimize ~tol ?debug ~maxit:10_000 ~f:(Model1.clip f) ~sample () in
-. ll, schema, p_star
......@@ -354,10 +385,10 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
| `Convergent -> f param.stationary_distribution1
| _ -> assert false
let log_likelihood ~exchangeability_matrix ~param tree site =
let log_likelihood ~exchangeability_matrix ~schema ~param tree site =
let f cond =
evolution_model_param exchangeability_matrix param cond
|> Evolution_model.transition_probability_matrix
|> Evolution_model.transition_probability_submatrix schema.Model2.mask
in
let transition_matrix =
let f0 = f `Ancestral in (* pre-computation *)
......@@ -368,9 +399,12 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
| `Ancestral -> (f0 bl :> mat)
| `Convergent -> (f1 bl :> mat)
in
let root_frequencies = (param.stationary_distribution0 :> Vector.t) in
let leaf_state = aa_of_leaf_info site in
CTMC.pruning_with_missing_values tree ~transition_matrix ~leaf_state ~root_frequencies
let root_frequencies = project_vector schema.mask param.stationary_distribution0 in
let leaf_state l =
aa_of_leaf_info site l
|> Option.map ~f:(Amino_acid.Table.get schema.idx_of_aa)
in
CTMC.pruning_with_missing_values tree ~nstates:schema.nz ~transition_matrix ~leaf_state ~root_frequencies
let tuple_map (x, y) ~f = (f x, f y)
......@@ -417,7 +451,7 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
in
let f vec =
let param = decode_vec schema vec in
-. log_likelihood ~exchangeability_matrix ~param tree site
-. log_likelihood ~exchangeability_matrix ~schema ~param tree site
in
let sample = Model2.nelder_mead_init theta0 in
let ll, p_star = Nelder_mead.minimize ~tol ?debug ~maxit:10_000 ~f:(Model1.clip f) ~sample () in
......
......@@ -9,7 +9,7 @@ module Evolution_model : sig
val param_of_wag : Wag.t -> float -> param
val rate_matrix : param -> Amino_acid.matrix
val stationary_distribution : param -> Amino_acid.vector
val transition_probability_matrix : param -> float -> Amino_acid.matrix
val transition_probability_matrix : param -> float -> Linear_algebra.Lacaml.mat
end
type likelihood_ratio_test = {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment