Commit 8cadf5c8 authored by Philippe Veber
Browse files

tk/Tdg09: implemented submatrix optimization for Model2

parent 7fa09ac2
......@@ -8,19 +8,33 @@ module Evolution_model = struct
exchangeability_matrix : Amino_acid.matrix ;
scale : float ;
}
(** [param_of_wag wag scale] builds a model parameter from a WAG matrix:
    the stationary distribution is [wag.freqs], the exchangeability matrix
    is [wag.rate_matrix], and [scale] is stored unchanged. *)
let param_of_wag (wag : Wag.t) scale =
  let stationary_distribution = wag.freqs in
  let exchangeability_matrix = wag.rate_matrix in
  { stationary_distribution ; exchangeability_matrix ; scale }
(** Accessor for the [stationary_distribution] field of a model parameter. *)
let stationary_distribution { stationary_distribution = pi ; _ } = pi
(* NOTE(review): in this diff view the next two lines look like the previous
   header of [rate_matrix], superseded by the definition given after
   [substitution_rate] — confirm against the commit. *)
let rate_matrix p =
Rate_matrix.Amino_acid.make (fun i j ->
(* [substitution_rate p i j] is the product of [p.scale], the exchangeability
   of the pair [(i, j)] and the stationary frequency of [j]. *)
let substitution_rate p i j =
p.scale *.
p.exchangeability_matrix.Amino_acid.%{i, j} *.
p.stationary_distribution.Amino_acid.%(j)
(* Builds the amino-acid rate matrix by handing [substitution_rate p] to
   [Rate_matrix.Amino_acid.make]. *)
let rate_matrix p =
Rate_matrix.Amino_acid.make (substitution_rate p)
(** [rate_submatrix mask p] builds a rate matrix with one state per entry of
    [mask]: the coefficient requested at [(i, j)] is the substitution rate
    between the amino acids [mask.(i)] and [mask.(j)]. *)
let rate_submatrix mask p =
  let rate_between row col = substitution_rate p mask.(row) mask.(col) in
  Rate_matrix.make (Array.length mask) ~f:rate_between
(** [transition_probability_submatrix mask p] computes the rate submatrix for
    [mask] once, then returns a function mapping a duration [t] to the matrix
    exponential of that submatrix scaled by [t]. *)
let transition_probability_submatrix mask p =
  (* Computed at partial application so it is shared across all values of [t]. *)
  let rates = rate_submatrix mask p in
  fun t -> Matrix.expm (Matrix.scal_mul t rates)
let transition_probability_matrix p =
let module V = Amino_acid.Vector in
let module M = Amino_acid.Matrix in
......@@ -59,12 +73,11 @@ module Evolution_model = struct
test_diagonal_matrix_exponential ()
end
(** [choose_aa p] converts the vector [p] to an amino-acid table and passes it
    to [Amino_acid.Table.choose]. *)
let choose_aa p =
  let table = Amino_acid.Table.of_vector p in
  Amino_acid.Table.choose table
(* NOTE(review): two bindings of [CTMC] appear back to back because this is a
   diff view; the functor application looks like the removed version and the
   plain alias like its replacement (the calls to
   [CTMC.pruning_with_missing_values] that pass [~nstates] are consistent
   with that) — confirm against the commit. *)
module CTMC = Phylo_ctmc.Make(Amino_acid)
module CTMC = Phylo_ctmc
(* Numerical tolerance; presumably the [~tol] handed to
   [Nelder_mead.minimize] elsewhere in this file — verify. *)
let tol = 0.001
......@@ -200,8 +213,8 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
let f = Evolution_model.transition_probability_matrix p in
fun b -> (f (Branch_info.length b) :> mat)
in
let leaf_state = aa_of_leaf_info site in
CTMC.pruning_with_missing_values tree ~transition_matrix ~leaf_state ~root_frequencies:pi
let leaf_state l = (aa_of_leaf_info site l :> int option) in
CTMC.pruning_with_missing_values tree ~nstates:Amino_acid.card ~transition_matrix ~leaf_state ~root_frequencies:pi
let clip f param =
if Float.(param.(0) > 3.) then Float.infinity
......@@ -239,49 +252,60 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
stationary_distribution : Amino_acid.vector ;
}
(* NOTE(review): this first header takes no [~schema] and appears to be the
   pre-change version shown by the diff view; the header that owns the body
   is the one following the [vec_schema] type — confirm. *)
let log_likelihood ~exchangeability_matrix ~param:{ stationary_distribution ; scale } tree site =
(* Describes a (possibly reduced) amino-acid state space. *)
type vec_schema = {
nz : int ; (* number of non-zero AA in count table *)
mask : Amino_acid.t array ; (* corresponding amino acids *)
idx_of_aa : int Amino_acid.table ; (* position of each amino acid in [mask] *)
}
(* Runs [CTMC.pruning_with_missing_values] on [tree] for [site], over the
   state space described by [schema]:
   - the transition matrices come from
     [Evolution_model.transition_probability_submatrix schema.mask],
     evaluated at each branch length;
   - leaf states are the amino acids of [site] translated to their position
     in the schema through [schema.idx_of_aa];
   - the number of states is [schema.nz]. *)
let log_likelihood ~exchangeability_matrix ~schema ~param:{ stationary_distribution ; scale } tree site =
let p = { Evolution_model.scale ; exchangeability_matrix ; stationary_distribution } in
let transition_matrix =
(* NOTE(review): [f] is bound twice in a row here; the first binding
   (full matrix) looks like the removed diff line, the second (submatrix)
   like its replacement — confirm. *)
let f = Evolution_model.transition_probability_matrix p in
let f = Evolution_model.transition_probability_submatrix schema.mask p in
fun b -> (f (Branch_info.length b) :> mat)
in
(* NOTE(review): the next two lines (a [leaf_state] without index translation
   and a call without [~nstates]) also look like removed diff lines. *)
let leaf_state = aa_of_leaf_info site in
CTMC.pruning_with_missing_values tree ~transition_matrix ~leaf_state ~root_frequencies:(stationary_distribution :> Vector.t)
let leaf_state l =
aa_of_leaf_info site l
|> Option.map ~f:(Amino_acid.Table.get schema.idx_of_aa)
in
(* NOTE(review): [root_frequencies] is the stationary distribution coerced as
   is, i.e. indexed by the full alphabet, whereas [nstates] and the
   transition matrices are sized by [schema.nz] and ordered like
   [schema.mask] — confirm that [pruning_with_missing_values] expects this
   when the schema is sparse. *)
CTMC.pruning_with_missing_values tree
~nstates:schema.nz ~transition_matrix ~leaf_state
~root_frequencies:(stationary_distribution :> Vector.t)
(** [counts xs] is the table giving, for every amino acid, how many elements
    of the list [xs] are equal to it. *)
let counts xs =
  let occurrences aa = List.count xs ~f:(Amino_acid.equal aa) in
  Amino_acid.Table.init occurrences
(* Schema describing which amino acids take part in a parameter vector.
   NOTE(review): this variant (with an [idx] field) seems to be the removed
   version in this diff view; the surrounding new code builds records with
   [mask] and [idx_of_aa] fields instead — confirm. *)
type vec_schema = {
nz : int ; (* number of non-zero AA in count table *)
idx : int array ; (* indices of non-zero AA *)
}
(* [sparse_param_schema counts] builds a schema restricted to the amino acids
   whose entry in [counts] is non-zero:
   - [nz] is the number of such amino acids;
   - [mask] lists them; since the fold prepends each hit to a list, they end
     up in the reverse of the traversal order;
   - [idx_of_aa] maps each amino acid to its position in [mask], and to [-1]
     when it is not in [mask]. *)
let sparse_param_schema counts =
let k = (counts : int Amino_acid.table :> _ array) in
let idx, nz = Array.foldi k ~init:([], 0) ~f:(fun i ((assoc, nz) as acc) k_i ->
(* NOTE(review): two consecutive [if] lines — the first (raw integer index)
   looks like the removed diff line, the second (converted with
   [Amino_acid.of_int_exn]) like its replacement. *)
if k_i = 0 then acc else i :: assoc, nz + 1
if k_i = 0 then acc else Amino_acid.of_int_exn i :: assoc, nz + 1
)
in
(* NOTE(review): the next two lines build the older [{ nz ; idx }] record and
   look like removed diff lines as well. *)
let idx = Array.of_list idx in
{ nz ; idx }
let mask = Array.of_list idx in
let idx_of_aa =
(* Start from -1 everywhere, then record the position of each masked
   amino acid. *)
let r = Amino_acid.Table.init (fun _ -> -1) in
Array.iteri mask ~f:(fun i aa -> Amino_acid.Table.set r aa i) ;
r
in
{ nz ; mask ; idx_of_aa }
(* NOTE(review): the first definition below (taking [counts] and producing
   [{ nz ; idx }]) looks like the removed version in this diff view; the one
   taking [()] looks like its replacement — confirm. *)
let dense_param_schema counts =
let nz = Array.length (counts : int Amino_acid.table :> _ array) in
let idx = Array.init nz ~f:Fn.id in
{ nz ; idx }
(* [dense_param_schema ()] is the schema covering the whole alphabet:
   [nz] is [Amino_acid.card], [mask.(i)] is [Amino_acid.of_int_exn i], and
   [idx_of_aa] maps each amino acid to [Amino_acid.to_int] of itself. *)
let dense_param_schema () =
let nz = Amino_acid.card in
let mask = Array.init nz ~f:Amino_acid.of_int_exn in
let idx_of_aa = Amino_acid.Table.init Amino_acid.to_int in
{ nz ; mask ; idx_of_aa }
(* NOTE(review): the first definition below (indexing a raw array through
   [schema.idx]) looks like the removed version in this diff view; the second
   (going through [Amino_acid.Table] and [schema.mask]) like its replacement
   — confirm. *)
let profile_guess schema counts =
let counts = (counts : int Amino_acid.table :> _ array) in
let total_counts = Array.fold counts ~init:0. ~f:(fun acc x -> 1. +. acc +. float x) in
Array.map schema.idx ~f:(fun idx -> Float.log (float (1 + counts.(idx)) /. total_counts))
(* [profile_guess schema counts] returns one value per entry of
   [schema.mask]: the logarithm of [(1 + count)] divided by [total_counts],
   where [total_counts] sums [(1 + count)] over every entry of [counts]
   (one pseudo-count per table entry). *)
let profile_guess schema counts = (* FIXME: are these pseudo-counts really useful? *)
let total_counts = Amino_acid.Table.fold counts ~init:0. ~f:(fun acc x -> 1. +. acc +. float x) in
Array.map schema.mask ~f:(fun idx -> Float.log (float (1 + Amino_acid.Table.get counts idx) /. total_counts))
(** [initial_param schema counts] is the starting parameter vector: a leading
    [0.] followed by the values of [profile_guess schema counts]. *)
let initial_param schema counts =
  let profile = profile_guess schema counts in
  Array.append [| 0. |] profile
(* [extract_frequencies ~offset schema param] turns the slice of [param]
   starting at [offset] into a frequency vector over the whole alphabet:
   each schema entry receives [exp param.(position + offset)], every other
   amino acid stays at [0.], and the result is divided by the sum of all
   entries. *)
let extract_frequencies ~offset schema param =
let r = Array.create ~len:Amino_acid.card 0. in
(* NOTE(review): two iteration headers follow; the one over [schema.idx]
   looks like the removed diff lines, the one over [schema.mask] like the
   replacement — confirm. *)
Array.iteri schema.idx ~f:(fun sparse_idx full_idx ->
r.(full_idx) <- Float.exp param.(sparse_idx + offset)
Array.iteri schema.mask ~f:(fun sparse_idx aa ->
r.((aa :> int)) <- Float.exp param.(sparse_idx + offset)
) ;
(* Normalise by the sum of all entries. *)
let s = Owl.Stats.sum r in
Amino_acid.Vector.init (fun aa -> r.((aa :> int)) /. s)
......@@ -289,7 +313,7 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
(* [param_schema ?mode counts] selects the schema constructor: [`sparse]
   (the default) calls [sparse_param_schema counts], [`dense] calls
   [dense_param_schema]. *)
let param_schema ?(mode = `sparse) counts =
match mode with
| `sparse -> sparse_param_schema counts
(* NOTE(review): two [`dense] cases appear because this is a diff view; the
   one passing [counts] looks like the removed line, the one passing [()]
   like its replacement — confirm. *)
| `dense -> dense_param_schema counts
| `dense -> dense_param_schema ()
let nelder_mead_init theta0 =
let c = ref (-1) in
......@@ -313,11 +337,12 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
|> counts
in
let schema = param_schema ?mode counts in
let module A = Alphabet.Make(struct let card = schema.nz end) in
let theta0 = initial_param schema counts in
let sample = nelder_mead_init theta0 in
let f p =
let param = decode_vec schema p in
-. log_likelihood ~exchangeability_matrix ~param tree site
-. log_likelihood ~exchangeability_matrix ~schema ~param tree site
in
let ll, p_star = Nelder_mead.minimize ~tol ?debug ~maxit:10_000 ~f:(Model1.clip f) ~sample () in
-. ll, schema, p_star
......@@ -369,8 +394,8 @@ module Make(Branch_info : Branch_info)(Leaf_info : Leaf_info)(Site : Site with t
| `Convergent -> (f1 bl :> mat)
in
let root_frequencies = (param.stationary_distribution0 :> Vector.t) in
let leaf_state = aa_of_leaf_info site in
CTMC.pruning_with_missing_values tree ~transition_matrix ~leaf_state ~root_frequencies
let leaf_state l = (aa_of_leaf_info site l :> int option) in
CTMC.pruning_with_missing_values tree ~nstates:Amino_acid.card ~transition_matrix ~leaf_state ~root_frequencies
(** [tuple_map pair ~f] applies [f] to both components of [pair] and returns
    the pair of results. *)
let tuple_map pair ~f =
  match pair with
  | left, right -> (f left, f right)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment