Commit 8cadf5c8 authored by Philippe Veber
Browse files

tk/Tdg09: implemented submatrix optimization for Model2

parent 7fa09ac2
...@@ -8,19 +8,33 @@ module Evolution_model = struct ...@@ -8,19 +8,33 @@ module Evolution_model = struct
exchangeability_matrix : Amino_acid.matrix ; exchangeability_matrix : Amino_acid.matrix ;
scale : float ; scale : float ;
} }
let param_of_wag (wag : Wag.t) scale = { let param_of_wag (wag : Wag.t) scale = {
scale ; scale ;
stationary_distribution = wag.freqs ; stationary_distribution = wag.freqs ;
exchangeability_matrix = wag.rate_matrix ; exchangeability_matrix = wag.rate_matrix ;
} }
let stationary_distribution p = p.stationary_distribution let stationary_distribution p = p.stationary_distribution
let substitution_rate p i j =
p.scale *.
p.exchangeability_matrix.Amino_acid.%{i, j} *.
p.stationary_distribution.Amino_acid.%(j)
let rate_matrix p = let rate_matrix p =
Rate_matrix.Amino_acid.make (fun i j -> Rate_matrix.Amino_acid.make (substitution_rate p)
p.scale *.
p.exchangeability_matrix.Amino_acid.%{i, j} *. let rate_submatrix mask p =
p.stationary_distribution.Amino_acid.%(j) let n = Array.length mask in
Rate_matrix.make n ~f:(fun i j ->
substitution_rate p mask.(i) mask.(j)
) )
let transition_probability_submatrix mask p =
let m = rate_submatrix mask p in
fun t -> Matrix.(expm (scal_mul t m))
let transition_probability_matrix p = let transition_probability_matrix p =
let module V = Amino_acid.Vector in let module V = Amino_acid.Vector in
let module M = Amino_acid.Matrix in let module M = Amino_acid.Matrix in
...@@ -59,12 +73,11 @@ module Evolution_model = struct ...@@ -59,12 +73,11 @@ module Evolution_model = struct
test_diagonal_matrix_exponential () test_diagonal_matrix_exponential ()
end end
let choose_aa p = let choose_aa p =
Amino_acid.Table.of_vector p Amino_acid.Table.of_vector p
|> Amino_acid.Table.choose |> Amino_acid.Table.choose
module CTMC = Phylo_ctmc.Make(Amino_acid) module CTMC = Phylo_ctmc
let tol = 0.001 let tol = 0.001
...@@ -200,8 +213,8 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t ...@@ -200,8 +213,8 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
let f = Evolution_model.transition_probability_matrix p in let f = Evolution_model.transition_probability_matrix p in
fun b -> (f (Branch_info.length b) :> mat) fun b -> (f (Branch_info.length b) :> mat)
in in
let leaf_state = aa_of_leaf_info site in let leaf_state l = (aa_of_leaf_info site l :> int option) in
CTMC.pruning_with_missing_values tree ~transition_matrix ~leaf_state ~root_frequencies:pi CTMC.pruning_with_missing_values tree ~nstates:Amino_acid.card ~transition_matrix ~leaf_state ~root_frequencies:pi
let clip f param = let clip f param =
if Float.(param.(0) > 3.) then Float.infinity if Float.(param.(0) > 3.) then Float.infinity
...@@ -239,49 +252,60 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t ...@@ -239,49 +252,60 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
stationary_distribution : Amino_acid.vector ; stationary_distribution : Amino_acid.vector ;
} }
let log_likelihood ~exchangeability_matrix ~param:{ stationary_distribution ; scale } tree site = type vec_schema = {
nz : int ; (* number of non-zero AA in count table *)
mask : Amino_acid.t array ; (* corresponding amino acids *)
idx_of_aa : int Amino_acid.table ;
}
let log_likelihood ~exchangeability_matrix ~schema ~param:{ stationary_distribution ; scale } tree site =
let p = { Evolution_model.scale ; exchangeability_matrix ; stationary_distribution } in let p = { Evolution_model.scale ; exchangeability_matrix ; stationary_distribution } in
let transition_matrix = let transition_matrix =
let f = Evolution_model.transition_probability_matrix p in let f = Evolution_model.transition_probability_submatrix schema.mask p in
fun b -> (f (Branch_info.length b) :> mat) fun b -> (f (Branch_info.length b) :> mat)
in in
let leaf_state = aa_of_leaf_info site in let leaf_state l =
CTMC.pruning_with_missing_values tree ~transition_matrix ~leaf_state ~root_frequencies:(stationary_distribution :> Vector.t) aa_of_leaf_info site l
|> Option.map ~f:(Amino_acid.Table.get schema.idx_of_aa)
in
CTMC.pruning_with_missing_values tree
~nstates:schema.nz ~transition_matrix ~leaf_state
~root_frequencies:(stationary_distribution :> Vector.t)
let counts xs = let counts xs =
Amino_acid.Table.init (fun aa -> List.count xs ~f:(Amino_acid.equal aa)) Amino_acid.Table.init (fun aa -> List.count xs ~f:(Amino_acid.equal aa))
type vec_schema = {
nz : int ; (* number of non-zero AA in count table *)
idx : int array ; (* indices of non-zero AA *)
}
let sparse_param_schema counts = let sparse_param_schema counts =
let k = (counts : int Amino_acid.table :> _ array) in let k = (counts : int Amino_acid.table :> _ array) in
let idx, nz = Array.foldi k ~init:([], 0) ~f:(fun i ((assoc, nz) as acc) k_i -> let idx, nz = Array.foldi k ~init:([], 0) ~f:(fun i ((assoc, nz) as acc) k_i ->
if k_i = 0 then acc else i :: assoc, nz + 1 if k_i = 0 then acc else Amino_acid.of_int_exn i :: assoc, nz + 1
) )
in in
let idx = Array.of_list idx in let mask = Array.of_list idx in
{ nz ; idx } let idx_of_aa =
let r = Amino_acid.Table.init (fun _ -> -1) in
Array.iteri mask ~f:(fun i aa -> Amino_acid.Table.set r aa i) ;
r
in
{ nz ; mask ; idx_of_aa }
let dense_param_schema counts = let dense_param_schema () =
let nz = Array.length (counts : int Amino_acid.table :> _ array) in let nz = Amino_acid.card in
let idx = Array.init nz ~f:Fn.id in let mask = Array.init nz ~f:Amino_acid.of_int_exn in
{ nz ; idx } let idx_of_aa = Amino_acid.Table.init Amino_acid.to_int in
{ nz ; mask ; idx_of_aa }
let profile_guess schema counts = let profile_guess schema counts = (* FIXME: are these pseudo-counts really useful? *)
let counts = (counts : int Amino_acid.table :> _ array) in let total_counts = Amino_acid.Table.fold counts ~init:0. ~f:(fun acc x -> 1. +. acc +. float x) in
let total_counts = Array.fold counts ~init:0. ~f:(fun acc x -> 1. +. acc +. float x) in Array.map schema.mask ~f:(fun idx -> Float.log (float (1 + Amino_acid.Table.get counts idx) /. total_counts))
Array.map schema.idx ~f:(fun idx -> Float.log (float (1 + counts.(idx)) /. total_counts))
let initial_param schema counts = let initial_param schema counts =
Array.append [| 0. |] (profile_guess schema counts) Array.append [| 0. |] (profile_guess schema counts)
let extract_frequencies ~offset schema param = let extract_frequencies ~offset schema param =
let r = Array.create ~len:Amino_acid.card 0. in let r = Array.create ~len:Amino_acid.card 0. in
Array.iteri schema.idx ~f:(fun sparse_idx full_idx -> Array.iteri schema.mask ~f:(fun sparse_idx aa ->
r.(full_idx) <- Float.exp param.(sparse_idx + offset) r.((aa :> int)) <- Float.exp param.(sparse_idx + offset)
) ; ) ;
let s = Owl.Stats.sum r in let s = Owl.Stats.sum r in
Amino_acid.Vector.init (fun aa -> r.((aa :> int)) /. s) Amino_acid.Vector.init (fun aa -> r.((aa :> int)) /. s)
...@@ -289,7 +313,7 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t ...@@ -289,7 +313,7 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
let param_schema ?(mode = `sparse) counts = let param_schema ?(mode = `sparse) counts =
match mode with match mode with
| `sparse -> sparse_param_schema counts | `sparse -> sparse_param_schema counts
| `dense -> dense_param_schema counts | `dense -> dense_param_schema ()
let nelder_mead_init theta0 = let nelder_mead_init theta0 =
let c = ref (-1) in let c = ref (-1) in
...@@ -313,11 +337,12 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t ...@@ -313,11 +337,12 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
|> counts |> counts
in in
let schema = param_schema ?mode counts in let schema = param_schema ?mode counts in
let module A = Alphabet.Make(struct let card = schema.nz end) in
let theta0 = initial_param schema counts in let theta0 = initial_param schema counts in
let sample = nelder_mead_init theta0 in let sample = nelder_mead_init theta0 in
let f p = let f p =
let param = decode_vec schema p in let param = decode_vec schema p in
-. log_likelihood ~exchangeability_matrix ~param tree site -. log_likelihood ~exchangeability_matrix ~schema ~param tree site
in in
let ll, p_star = Nelder_mead.minimize ~tol ?debug ~maxit:10_000 ~f:(Model1.clip f) ~sample () in let ll, p_star = Nelder_mead.minimize ~tol ?debug ~maxit:10_000 ~f:(Model1.clip f) ~sample () in
-. ll, schema, p_star -. ll, schema, p_star
...@@ -369,8 +394,8 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t ...@@ -369,8 +394,8 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
| `Convergent -> (f1 bl :> mat) | `Convergent -> (f1 bl :> mat)
in in
let root_frequencies = (param.stationary_distribution0 :> Vector.t) in let root_frequencies = (param.stationary_distribution0 :> Vector.t) in
let leaf_state = aa_of_leaf_info site in let leaf_state l = (aa_of_leaf_info site l :> int option) in
CTMC.pruning_with_missing_values tree ~transition_matrix ~leaf_state ~root_frequencies CTMC.pruning_with_missing_values tree ~nstates:Amino_acid.card ~transition_matrix ~leaf_state ~root_frequencies
let tuple_map (x, y) ~f = (f x, f y) let tuple_map (x, y) ~f = (f x, f y)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment