Commit fa3d4f8e authored by Louis Duchemin

Diffelstein bug-hunt still going on

parent 93cfef4f
@@ -56,11 +56,14 @@ let trans_matrix branch_length param =
Tensor.mul mat (exp_term branch_length param.mu)
|> Tensor.(add (f 0.25))
let logger likelihood step param =
let logger likelihood step param lr =
Tensor.(no_grad (fun () ->
sprintf
"Lik : %f ; Step : %d ; mu : %f ; grad : %f"
(Tensor.(likelihood |> to_float0_exn)) step (Tensor.to_float1_exn param.stationary_distribution).(1) !x
"Lik : %f ; Step : %d ; mu : %f ; lr : %f"
(Tensor.(likelihood |> to_float0_exn))
step
(Tensor.to_float1_exn param.stationary_distribution).(1)
lr
|> print_endline;
))
@@ -45,7 +45,7 @@ let trans_matrix branch_length param =
|> Tensor.diag ~diagonal:0 in
let mat = Tensor.matmul param.exchangeability_matrix diag_pi in
Tensor.(matrix_exp ((f branch_length * param.mu) * mat))
(* |> Tensor.transpose ~dim0:0 ~dim1:1 *)
(* |> Tensor.transpose ~dim0:0 ~dim1:1 *)
let wag = P.Wag.parse "data/wag.dat"
let param = param_of_wag wag mu
@@ -58,14 +58,14 @@ let update_param param ~lr =
zero_grad param.mu;
)
let logger likelihood step param =
let logger likelihood step param lr =
Tensor.(no_grad (fun () ->
sprintf
"Lik : %f ; Step : %d ; mu : %f"
"Lik : %f ; Step : %d ; mu : %f ; lr : %f"
(Tensor.(likelihood |> to_float0_exn))
step
(Tensor.to_float0_exn param.mu)
(*!x*)
lr
|> print_endline;
))
open Core_kernel
open Phylogenetics
open Torch
module Tk = Codepitk
module Leaf_info = struct
type t = int
type species = int
let condition _ = `Ancestral
let species x = x
end
module Branch_info = struct
type t = float
let length l = l
let condition _ = `Ancestral
end
module Site = struct
type t = Amino_acid.t array
type species = int
let get_aa (al:t) sp = Some al.(sp)
end
module TDG09 = Tk.Tdg09.Make(Branch_info)(Leaf_info)(Site)
let wag = Wag.parse "data/wag.dat"
let newick_str =
(* "(t2:0.3934020305,t1:1.);" *)
"(t2:2.400985762,(t5:2.20248759,(t1:0.0700889683,t3:0.0700889683):2.132398622):0.1984981724);"
(* "((t2:2.400985762,(t5:2.20248759,(t1:0.0700889683,t3:0.0700889683):2.132398622):0.1984981724):45.82815443,(t4:2.417988309,t6:2.417988309):45.81115189):3.793640769;" *)
let () = Gsl.Rng.set_default_seed @@ Nativeint.of_int 0;;
let rng = Gsl.Rng.(make (default ()))
let tree = Newick.from_string_exn newick_str
let scale = 2.
let sim_tree = Newick.with_inner_tree tree ~f:(fun tree ->
let tree = Tree.map tree ~node:Fn.id ~leaf:Fn.id
~branch:(fun b -> Option.value_exn b.length)
in
TDG09.Model1.simulate_site tree ~rng ~param:scale
~exchangeability_matrix:wag.rate_matrix
~stationary_distribution:wag.freqs
)
type diffparam = {
stationary_distribution : Tensor.t ;
scale : Tensor.t ;
exchangeability_matrix: Tensor.t;
}
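(* Both learnable fields are stored in unconstrained form: [scale] is kept in
   log-space (it is exponentiated wherever it is used) and
   [stationary_distribution] holds raw reals that are softmax-normalised by
   [stationary_distribution_of_params] below, so gradient steps never leave
   the valid parameter space. *)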
let param_of_wag (wag:Wag.t) scale = {
scale=Tensor.f scale |> Tensor.set_requires_grad ~r:true;
stationary_distribution =
(* *)
(* Amino_acid.Vector.to_array wag.freqs *)
Array.init Amino_acid.card ~f:(fun i ->
(* 0. *)
(* if i = 1 || i = 7 then 10.
else if i = 17 then 20.
else 0. *)
if i = 0 then 1. else 0.
)
|> Tensor.of_float1
|> Tensor.set_requires_grad ~r:true ;
exchangeability_matrix =
Array.init Amino_acid.card ~f:(fun i ->
Array.init Amino_acid.card ~f:(fun j ->
Amino_acid.Matrix.get wag.rate_matrix i j
)
)
|> Tensor.of_float2
}
let rec difftree_of_simtree tree =
match tree with
| Tree.Leaf l -> Tk.Diffelstein.Leaf (Amino_acid.to_int l)
| Tree.Node n ->
match List1.to_list n.branches with
| [Branch l ; Branch r] -> Tk.Diffelstein.(
Node (
{length=l.data; tip=(difftree_of_simtree l.tip)},
{length=r.data; tip=(difftree_of_simtree r.tip)}
)
)
| _ -> failwith "non binary node found"
let difftree = difftree_of_simtree sim_tree
let diffparams = param_of_wag wag 0.
let stationary_distribution_of_params param =
Tensor.(
exp param.stationary_distribution /
sum (exp param.stationary_distribution)
)
let number_leaves t =
let id = ref (-1) in
Tree.map t ~leaf:(fun _ -> incr id ; !id)
~node:Fn.id ~branch:Fn.id
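(* Transition probability matrix over a branch:
   P(l) = exp(l * e^scale * S * diag(pi)),
   with S the exchangeability matrix and pi the softmax-normalised
   stationary distribution. *)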
let trans_matrix branch_length param =
let diag_pi = stationary_distribution_of_params param
|> Tensor.diag ~diagonal:0 in
let mat = Tensor.matmul param.exchangeability_matrix diag_pi in
Tensor.(matrix_exp ((f branch_length * exp param.scale) * mat))
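(* One gradient-ascent step: print the gradients for debugging, move each
   learnable parameter along its gradient scaled by [lr], then reset the
   accumulated gradients. *)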
let update_param param ~lr =
Tensor.(
pp Format.std_formatter @@ grad param.stationary_distribution;
pp Format.std_formatter @@ grad param.scale;
param.stationary_distribution += grad param.stationary_distribution * f lr;
param.scale += grad param.scale * f lr;
zero_grad param.stationary_distribution;
zero_grad param.scale;
)
let param_to_tdg09 param =
{ Tk.Tdg09.Evolution_model.
scale = Tensor.to_float0_exn param.scale |> Float.exp;
exchangeability_matrix = Tensor.to_float2_exn param.exchangeability_matrix
|> Amino_acid.Matrix.of_arrays_exn;
stationary_distribution = Tensor.to_float1_exn param.stationary_distribution
|> Array.map ~f:Float.exp
|> Amino_acid.Vector.of_array_exn
|> Amino_acid.Vector.normalize;
}
let () =
let diff_transmat = trans_matrix 1. diffparams in
let tdg09_params = param_to_tdg09 diffparams in
let tdg09_transmat = Tk.Tdg09.Evolution_model.transition_probability_matrix
tdg09_params 1. in
Tensor.pp Format.std_formatter diff_transmat;
(* for posterity: compute the difference between the two matrices and print the largest delta in absolute value (sketched below) *)
Linear_algebra.Lacaml.Matrix.pp Format.std_formatter tdg09_transmat;
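(* Hypothetical sketch for the TODO above: report the largest absolute
   difference between the Torch and Lacaml transition matrices. Assumes the
   Lacaml matrix exposes [Linear_algebra.Lacaml.Matrix.get : mat -> int -> int -> float]. *)
let torch_mat = Tensor.to_float2_exn diff_transmat in
let max_delta = ref 0. in
for i = 0 to Amino_acid.card - 1 do
  for j = 0 to Amino_acid.card - 1 do
    let delta =
      Float.abs (torch_mat.(i).(j) -. Linear_algebra.Lacaml.Matrix.get tdg09_transmat i j)
    in
    max_delta := Float.max !max_delta delta
  done
done;
printf "Max |delta| between transition matrices : %g\n" !max_delta;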
let torch_lik = Tk.Diffelstein.pruning difftree
~param:diffparams
~trans_matrix
~stationary_distribution:stationary_distribution_of_params
~alphabet_size:Amino_acid.card
and lacaml_lik = Tk.Diffelstein.Lacaml_pruning.pruning difftree
~param:tdg09_params
~trans_matrix:(fun l p -> Tk.Diffelstein.Lacaml_tensor.Matrix (Tk.Tdg09.Evolution_model.transition_probability_matrix p l))
~stationary_distribution:(fun p -> Tk.Diffelstein.Lacaml_tensor.Vector (Tk.Tdg09.Evolution_model.stationary_distribution p :> Phylogenetics.Linear_algebra.Lacaml.vec))
~alphabet_size:Amino_acid.card
and phylo_ctmc_lik =
Phylo_ctmc.pruning
sim_tree
~nstates:20
~transition_matrix:(fun l -> Tk.Tdg09.Evolution_model.transition_probability_matrix
tdg09_params l)
~leaf_state:Amino_acid.to_int
~root_frequencies:(Tk.Tdg09.Evolution_model.stationary_distribution tdg09_params :> Phylogenetics.Linear_algebra.Lacaml.vec)
in
Out_channel.newline Out_channel.stdout;
Tensor.to_float0_exn torch_lik |> string_of_float |> print_endline;
(match lacaml_lik with
|Float lik -> lik |> string_of_float |> print_endline;
| _ -> failwith "Something bad happened");
string_of_float (Float.exp phylo_ctmc_lik) |> print_endline;
let logger likelihood step param lr =
Tensor.(no_grad (fun () ->
if Int.(step mod 100 = 0) then
printf
"Lik : %f ; Step : %d ; scale : %f ; lr : %f\n%!"
(Tensor.(log likelihood |> to_float0_exn))
step
(Tensor.to_float0_exn param.scale |> Float.exp)
lr
))
in
let dif_param = Tk.Diffelstein.likelihood_optim difftree
~param:diffparams
~trans_matrix
~stationary_distribution:stationary_distribution_of_params
~alphabet_size:Amino_acid.card
~lr:100.
~steps:2
~update_param
~logger
~delta:1e-3
()
in
Tensor.pp Format.std_formatter @@ stationary_distribution_of_params dif_param;
Out_channel.(newline stdout);
let newick_tree = Tree.map sim_tree
~node:(fun _ -> Newick_ast.{name=None};)
~leaf:(fun leaf -> Newick_ast.{
name=Some (Amino_acid.to_char leaf |> Char.to_string) ;
})
~branch:(fun branch -> Newick_ast.{length=Some branch; tags=[]})
|> Newick_ast.Tree
in
"Tree : " ^ Newick.to_string newick_tree |> print_endline;
Out_channel.(newline stdout);
(*
let site = Tree.leaves sim_tree |> List.to_array in
let tdg09_lik, tdg09_param = TDG09.Model2.maximum_log_likelihood
~debug:true
~mode:`dense
~exchangeability_matrix:wag.rate_matrix
(number_leaves sim_tree)
site
in
print_endline @@ sprintf "TDG09 lik = %f" tdg09_lik;
print_endline @@ sprintf "TDG09 scale = %f" tdg09_param.scale;
print_endline "TDG09 dist";
Amino_acid.Vector.pp Format.std_formatter tdg09_param.stationary_distribution; *)
@@ -48,3 +48,11 @@
(libraries codepi)
(preprocess
(pps ppx_jane ppx_csv_conv bistro.ppx ppx_here)))
(executable
(name diffelstein_tdg)
(public_name diffelstein_tdg)
(modules diffelstein_tdg)
(libraries codepi)
(preprocess
(pps ppx_jane ppx_csv_conv bistro.ppx ppx_here)))
0.551571
0.509848 0.635346
0.738998 0.147304 5.429420
1.027040 0.528191 0.265256 0.0302949
0.908598 3.035500 1.543640 0.616783 0.0988179
1.582850 0.439157 0.947198 6.174160 0.021352 5.469470
1.416720 0.584665 1.125560 0.865584 0.306674 0.330052 0.567717
0.316954 2.137150 3.956290 0.930676 0.248972 4.294110 0.570025 0.249410
0.193335 0.186979 0.554236 0.039437 0.170135 0.113917 0.127395 0.0304501 0.138190
0.397915 0.497671 0.131528 0.0848047 0.384287 0.869489 0.154263 0.0613037 0.499462 3.170970
0.906265 5.351420 3.012010 0.479855 0.0740339 3.894900 2.584430 0.373558 0.890432 0.323832 0.257555
0.893496 0.683162 0.198221 0.103754 0.390482 1.545260 0.315124 0.174100 0.404141 4.257460 4.854020 0.934276
0.210494 0.102711 0.0961621 0.0467304 0.398020 0.0999208 0.0811339 0.049931 0.679371 1.059470 2.115170 0.088836 1.190630
1.438550 0.679489 0.195081 0.423984 0.109404 0.933372 0.682355 0.243570 0.696198 0.0999288 0.415844 0.556896 0.171329 0.161444
3.370790 1.224190 3.974230 1.071760 1.407660 1.028870 0.704939 1.341820 0.740169 0.319440 0.344739 0.967130 0.493905 0.545931 1.613280
2.121110 0.554413 2.030060 0.374866 0.512984 0.857928 0.822765 0.225833 0.473307 1.458160 0.326622 1.386980 1.516120 0.171903 0.795384 4.378020
0.113133 1.163920 0.0719167 0.129767 0.717070 0.215737 0.156557 0.336983 0.262569 0.212483 0.665309 0.137505 0.515706 1.529640 0.139405 0.523742 0.110864
0.240735 0.381533 1.086000 0.325711 0.543833 0.227710 0.196303 0.103604 3.873440 0.420170 0.398618 0.133264 0.428437 6.454280 0.216046 0.786993 0.291148 2.485390
2.006010 0.251849 0.196246 0.152335 1.002140 0.301281 0.588731 0.187247 0.118358 7.821300 1.800340 0.305434 2.058450 0.649892 0.314887 0.232739 1.388230 0.365369 0.314730
0.0866279 0.043972 0.0390894 0.0570451 0.0193078 0.0367281 0.0580589 0.0832518 0.0244313 0.048466 0.086209 0.0620286 0.0195027 0.0384319 0.0457631 0.0695179 0.0610127 0.0143859 0.0352742 0.0708956
A R N D C Q E G H I L K M F P S T W Y V
Ala Arg Asn Asp Cys Gln Glu Gly His Ile Leu Lys Met Phe Pro Ser Thr Trp Tyr Val
Symmetrical part of the WAG rate matrix and aa frequencies,
estimated from 3905 globular protein amino acid sequences forming 182
protein families.
The first part above indicates the symmetric 'exchangeability'
parameters, where s_ij = s_ji. The s_ij above are not scaled, but the
PAML package will perform this scaling.
The second part gives the amino acid frequencies (pi_i)
estimated from the 3905 sequences. The net replacement rate from i to
j is Q_ij = s_ij*pi_j.
Prepared by Simon Whelan and Nick Goldman, December 2000.
Citation:
Whelan, S. and N. Goldman. 2001. A general empirical model of
protein evolution derived from multiple protein families using
a maximum likelihood approach. Molecular Biology and
Evolution 18:691-699.
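As a side note on the Q_ij = s_ij * pi_j relation above: this is exactly the product
the tensor code in this commit builds before taking the matrix exponential. A minimal
sketch, assuming open Torch and hypothetical names exchangeabilities / frequencies for
the 20x20 exchangeability tensor and frequency vector:

let rate_matrix ~exchangeabilities ~frequencies =
  (* Q = S * diag(pi): entry (i, j) is s_ij * pi_j *)
  Tensor.matmul exchangeabilities (Tensor.diag frequencies ~diagonal:0)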
@@ -9,25 +9,50 @@ and branch = {
tip : tree ;
}
module type Tensor = sig
type t
val one_hot : int -> num_classes:int -> t
val mul : t -> t -> t
val matmul : t -> t -> t
val dot : t -> t -> t
end
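(* Minimal tensor operations required by the pruning recursion, so the same
   algorithm can be instantiated with Torch tensors (for automatic
   differentiation) or with Lacaml vectors and matrices (as a reference
   implementation). *)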
let rec conditional_likelihood tree ~param ~trans_matrix ~alphabet_size : Tensor.t =
match tree with
| Leaf l -> Tensor.(one_hot (of_int0 l) ~num_classes:alphabet_size
|> _cast_float ~non_blocking:false)
| Node (l, r) ->
Tensor.mul
(integrate_likelihood l ~param ~trans_matrix ~alphabet_size)
(integrate_likelihood r ~param ~trans_matrix ~alphabet_size)
and integrate_likelihood branch ~param ~trans_matrix ~alphabet_size : Tensor.t =
Tensor.matmul
(trans_matrix branch.length param)
(conditional_likelihood branch.tip ~param ~trans_matrix ~alphabet_size)
let pruning tree ~param ~trans_matrix ~stationary_distribution ~alphabet_size =
Tensor.dot
(stationary_distribution param)
(conditional_likelihood tree ~param ~trans_matrix ~alphabet_size)
module Pruning(Tensor : Tensor) = struct
let rec conditional_likelihood tree ~param ~trans_matrix ~alphabet_size : Tensor.t =
match tree with
| Leaf l -> Tensor.one_hot l ~num_classes:alphabet_size
| Node (l, r) ->
Tensor.mul
(integrate_likelihood l ~param ~trans_matrix ~alphabet_size)
(integrate_likelihood r ~param ~trans_matrix ~alphabet_size)
and integrate_likelihood branch ~param ~trans_matrix ~alphabet_size : Tensor.t =
Tensor.matmul
(trans_matrix branch.length param)
(conditional_likelihood branch.tip ~param ~trans_matrix ~alphabet_size)
let pruning tree ~param ~trans_matrix ~stationary_distribution ~alphabet_size =
Tensor.dot
(stationary_distribution param)
(conditional_likelihood tree ~param ~trans_matrix ~alphabet_size)
end
module Pytorch_tensor = struct
type t = Tensor.t
let one_hot i ~num_classes =
Array.init num_classes ~f:(fun x -> if x = i then 1. else 0.)
|> Tensor.of_float1
(* Tensor.(one_hot (of_int0 i) ~num_classes
|> _cast_float ~non_blocking:false) *)
let matmul = Tensor.matmul
let dot = Tensor.dot
let mul = Tensor.mul
end
include Pruning(Pytorch_tensor)
let rec likelihood_optim tree
~param ~trans_matrix ~stationary_distribution ~alphabet_size
@@ -47,9 +72,38 @@ let rec likelihood_optim tree
match last_likelihood with
| Some last_lik ->
let diff = Float.(log likelihood_f -. log last_lik) in
let lr = if Float.(diff < 0.) then lr /. 2. else lr *. 1.1 in
let lr = if Float.(diff < 0.) then lr *. 0.9 else lr *. 1.01 in
if Float.(lr < delta) || Float.is_nan (Tensor.to_float0_exn likelihood)
then param
else optimize ~lr
| None -> optimize ~lr
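(* Lacaml backend for the pruning functor: the same recursion evaluated with
   plain vectors and matrices, used to cross-check the Torch likelihood. *)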
module Lacaml_tensor = struct
module L = Phylogenetics.Linear_algebra.Lacaml
type t =
| Float of float
| Vector of L.vec
| Matrix of L.mat
let one_hot i ~num_classes =
Vector (L.Vector.init num_classes ~f:(fun j -> if i = j then 1. else 0.))
let mul x y =
match x, y with
| Float x, Float y -> Float (x *. y)
| Vector x, Vector y -> Vector (L.Vector.mul x y)
| Matrix x, Matrix y -> Matrix (L.Matrix.mul x y)
| _ -> invalid_arg "mul"
let dot x y =
match x, y with
| Vector x, Vector y -> Float (L.Vector.sum (L.Vector.mul x y))
| _ -> invalid_arg "dot"
let matmul x y =
match x, y with
| Matrix x, Vector y -> Vector (L.Matrix.apply x y)
| _ -> invalid_arg "matmul"
end
module Lacaml_pruning = Pruning(Lacaml_tensor)
\ No newline at end of file
@@ -3,35 +3,47 @@ open Torch
type tree = Node of branch * branch | Leaf of int
and branch = { length : float; tip : tree; }
module type Tensor = sig
type t
val one_hot : int -> num_classes:int -> t
val mul : t -> t -> t
val matmul : t -> t -> t
val dot : t -> t -> t
end
val conditional_likelihood :
tree ->
param:'a ->
trans_matrix:(float -> 'a -> Tensor.t) ->
alphabet_size: int ->
Tensor.t
module Pruning(Tensor : Tensor) : sig
val conditional_likelihood :
tree ->
param:'a ->
trans_matrix:(float -> 'a -> Tensor.t) ->
alphabet_size: int ->
Tensor.t
val integrate_likelihood :
branch ->
param:'a ->
trans_matrix:(float -> 'a -> Tensor.t) ->
alphabet_size: int ->
Tensor.t
val integrate_likelihood :
branch ->
param:'a ->
trans_matrix:(float -> 'a -> Tensor.t) ->
alphabet_size: int ->
Tensor.t
val pruning :
tree ->
param:'a ->
trans_matrix:(float -> 'a -> Tensor.t) ->
val pruning :
tree ->
param:'a ->
trans_matrix:(float -> 'a -> Tensor.t) ->
stationary_distribution: ('a -> Tensor.t) ->
alphabet_size: int ->
Tensor.t
alphabet_size: int ->
Tensor.t
end
module Pytorch_tensor : Tensor with type t = Torch.Tensor.t
include module type of Pruning(Pytorch_tensor)
val likelihood_optim :
tree ->
param:'a ->
trans_matrix:(float -> 'a -> Tensor.t) ->
stationary_distribution:('a -> Tensor.t) ->
alphabet_size: int ->
alphabet_size: int ->
lr:float ->
steps:int ->
?last_likelihood:float ->
@@ -40,3 +52,14 @@ val likelihood_optim :
?delta: float ->
unit ->
'a
module Lacaml_tensor : sig
module L = Phylogenetics.Linear_algebra.Lacaml
type t =
| Float of float
| Vector of L.vec
| Matrix of L.mat
include Tensor with type t := t
end
module Lacaml_pruning : module type of Pruning(Lacaml_tensor)
\ No newline at end of file
@@ -31,6 +31,7 @@ let newick_str =
"(t2:2.400985762,(t5:2.20248759,(t1:0.0700889683,t3:0.0700889683):2.132398622):0.1984981724);"
(* "((t2:2.400985762,(t5:2.20248759,(t1:0.0700889683,t3:0.0700889683):2.132398622):0.1984981724):45.82815443,(t4:2.417988309,t6:2.417988309):45.81115189):3.793640769;" *)
let () = Gsl.Rng.set_default_seed @@ Nativeint.of_int 0;;
let rng = Gsl.Rng.(make (default ()))
let tree = Newick.from_string_exn newick_str
let scale = 2.
@@ -52,6 +53,7 @@ type diffparam = {
let param_of_wag (wag:Wag.t) scale = {
scale=Tensor.f scale |> Tensor.set_requires_grad ~r:true;
stationary_distribution =
(* *)
(* Amino_acid.Vector.to_array wag.freqs *)
Array.init Amino_acid.card ~f:(fun _ -> 0.)
|> Tensor.of_float1
@@ -102,6 +104,8 @@ let trans_matrix branch_length param =
let update_param param ~lr =
Tensor.(
param.stationary_distribution += grad param.stationary_distribution * f lr;
(* let new_scale = param.scale + grad param.scale * f lr in
if Float.(to_float0_exn new_scale > 0.) then *)
param.scale += grad param.scale * f lr;
zero_grad param.stationary_distribution;
zero_grad param.scale;
@@ -147,7 +151,7 @@ let () =
let site = Tree.leaves sim_tree |> List.to_array in
let tdg09_lik, tdg09_param = TDG09.Model2.maximum_log_likelihood
~debug:true
(* ~debug:true *)
~mode:`dense
~exchangeability_matrix:wag.rate_matrix
(number_leaves sim_tree)