(* inhouse_lmm.ml *)
open Core_kernel
open Phylogenetics
module L = Lacaml.D
module BI = Phylogenetics_convergence.Simulator.Branch_info

(** Intermediate result of the bottom-up traversal performed by
    [correlations]: the [(leaf, leaf, time)] entries gathered so far,
    paired with the set of leaf names seen in the current subtree. *)
type correlations =
  (string * string * float) list * String.Set.t

(** [merge_correlations time_from_ancestor left right] combines the results
    of two sibling subtrees that meet at a common node.

    Every pair [(a, b)] with [a] a leaf of [left] and [b] a leaf of [right]
    is recorded with [time_from_ancestor] as its shared time. Only this one
    orientation is emitted; [(b, a)] is not added. The entries already
    collected on both sides are kept, and the two leaf sets are unioned.
    When called from [correlations], [time_from_ancestor] is the summed
    branch length from the root to the merging node. *)
let merge_correlations time_from_ancestor ((dist_l, l) : correlations)
    ((dist_r, r) : correlations) =
  let right_leaves = String.Set.to_list r in
  let cross_pairs =
    String.Set.fold l ~init:[] ~f:(fun pairs left_leaf ->
        List.fold right_leaves ~init:pairs ~f:(fun pairs right_leaf ->
            (left_leaf, right_leaf, time_from_ancestor) :: pairs))
  in
  (dist_l @ dist_r @ cross_pairs, String.Set.union l r)

(** [correlations t] walks [t] from the root, accumulating branch lengths.

    It returns:
    - [(leaf, leaf, d)] for every leaf, where [d] is the summed branch
      length from the root to that leaf;
    - [(a, b, d)] for every pair of leaves found under different child
      branches of the same node, where [d] is the summed branch length from
      the root to that node. Each such pair is listed once, in a single
      orientation.

    Raises (via [Option.value_exn]) if a leaf carries no name. *)
let correlations (t : Convergence_tree.u) : (string * string * float) list
    =
  let rec of_subtree depth node =
    match node with
    | Tree.Leaf leaf ->
        let name = Option.value_exn leaf.Newick.name in
        ([ (name, name, depth) ], String.Set.singleton name)
    | Tree.Node node_data ->
        let per_branch = List1.map node_data.branches ~f:(of_branch depth) in
        List1.reduce per_branch ~f:(merge_correlations depth)
  and of_branch depth (Tree.Branch br) =
    (* The depth of a child is its parent's depth plus the branch length. *)
    of_subtree (depth +. br.data.BI.length) br.tip
  in
  let pairs, _leaves = of_subtree 0. t in
  pairs

(** [design_matrix ~m ~n ~aa_at_site al] builds an [m x n] matrix with one
    row per sequence of [al] (rows [0 .. m-1] of [al.sequences]).

    Using 0-based column indices:
    - column [0] is an all-ones intercept;
    - column [j >= 1] holds [1.] when the sequence's character matches
      [aa_at_site.(j - 1)], and [0.] otherwise.

    Only [n - 1] residues get an indicator column, so the last entry of
    [aa_at_site] has none and acts as the reference level.

    The rows are built as 0-based OCaml arrays and converted with
    [L.Mat.of_array]. Lacaml's [Mat.init_rows] calls its function with
    1-based indices, so with that constructor the [j = 0] intercept test
    never held and [al.sequences.(m)] was read out of bounds.

    NOTE(review): the indicator compares character [j - 1] of each
    sequence, not the character at the alignment column from which
    [aa_at_site] was drawn (see [lrt_on_one_site]). This presumably should
    read the site column; confirm, then add a [~site] argument and update
    callers. *)
let design_matrix ~m ~n ~aa_at_site (al : Alignment.t) =
  Array.init m ~f:(fun i ->
      let sequence = al.sequences.(i) in
      Array.init n ~f:(fun j ->
          if j = 0 then 1.
          else if Char.(sequence.[j - 1] = aa_at_site.(j - 1)) then 1.
          else 0.))
  |> L.Mat.of_array

(** Fitted values of a linear model: the matrix-vector product
    [_X_ * theta] (no transposition). *)
let predict_y ~_X_ ~theta =
  L.gemv ~trans:`N _X_ theta

(** Sum of squared element-wise differences between [y_r] and [y_q],
    i.e. the squared Euclidean norm of [y_r - y_q]. *)
let squares_sum ~y_r ~y_q =
  let difference = L.Vec.sub y_r y_q in
  L.Vec.sqr_nrm2 difference

(** F statistic comparing two fitted linear models of the observations [y].

    The model described by [_X_r]/[theta_r] is presumably the larger one
    and [_X_q]/[theta_q] the smaller nested one. The numerator's degrees of
    freedom are [rank_r - rank_q], which must be positive.

    With [y_r] and [y_q] the two fitted vectors and [n = dim y]:
    {[
      F = (||y_r - y_q||^2 / (rank_r - rank_q))
          / (||y_r - y||^2 / (n - rank_r))
    ]}

    Float division by zero yields [nan] or [infinity] when
    [rank_r = rank_q] or [n = rank_r]. *)
let f_stat ~y ~_X_r ~theta_r ~_X_q ~theta_q =
  let y_r = predict_y ~_X_:_X_r ~theta:theta_r in
  let y_q = predict_y ~_X_:_X_q ~theta:theta_q in
  (* Squared distance between the two fits (numerator sum of squares). *)
  let scm_rq = squares_sum ~y_r ~y_q in
  (* Residual sum of squares of the [r] model. *)
  let scr_r = squares_sum ~y_r ~y_q:y in
  (* Both design matrices are assumed to have full column rank (the
     coefficients come from the normal equations), so each rank is the
     number of columns. [rank_q] used to be hard-coded to 1, which is only
     right for an intercept-only [_X_q]. *)
  let rank_r = Float.of_int (L.Mat.dim2 _X_r) in
  let rank_q = Float.of_int (L.Mat.dim2 _X_q) in
  let n = Float.of_int (L.Vec.dim y) in
  scm_rq /. (rank_r -. rank_q) /. (scr_r /. (n -. rank_r))

(** [solve ~y ~_X_ ~_T_] returns the least-squares coefficient vector
    [theta] of the transformed problem.

    The problem is [y~ ~ X~ theta], where [X~ = _T_ * _X_] and
    [y~ = _T_ * y]. [theta] is obtained from the normal equations
    [(X~' X~) theta = X~' y~].

    [_T_] is presumably a whitening transform derived from the covariance
    matrix (e.g. an inverse Cholesky factor); TODO confirm with callers.
    Lacaml reports a singular [X~' X~] (e.g. a rank-deficient design) as an
    exception. *)
let solve ~y ~_X_ ~_T_ =
  (* Matrix product T X. [L.Mat.mul] would be the element-wise product. *)
  let _Xtilde_ = L.gemm _T_ _X_ in
  let ytilde = L.gemv _T_ y in
  let gram = L.gemm ~transa:`T _Xtilde_ _Xtilde_ in
  let rhs = L.gemv ~trans:`T _Xtilde_ ytilde in
  (* [getri] overwrites [gram] with its inverse and returns unit. *)
  L.getri gram;
  L.gemv gram rhs

(** Per-site test; currently a stub.

    It collects the distinct residues observed at column [site] of the
    alignment and builds the corresponding design matrix. It then discards
    that matrix and returns [()]. [~phenotypes] and [~_C_] are accepted but
    not used yet. *)
let lrt_on_one_site ~alignment:al ~phenotypes:_ ~_C_ ~site =
  let nb_sequences = Alignment.nrows al in
  let residues =
    Array.of_list (Char.Set.to_list (Alignment.residues al ~column:site))
  in
  let _design =
    design_matrix ~m:nb_sequences ~n:(Array.length residues)
      ~aa_at_site:residues al
  in
  ()

(** Lacaml vector holding one entry per leaf of [t], in the order returned
    by [Convergence_tree.leaves]: [0.] for an [`Ancestral] leaf and [1.]
    for a [`Convergent] one. *)
let phenotypes_of_tree t =
  let encode = function `Ancestral -> 0. | `Convergent -> 1. in
  Convergence_tree.leaves t
  |> List.map ~f:(fun (_, condition) -> encode condition)
  |> L.Vec.of_list

(** Intended entry point: a per-column test over [alignment], using the
    leaf phenotypes encoded from [tree]. Not implemented yet. Unless an
    earlier step raises, it always ends by raising [Assert_failure].

    NOTE(review): [lrt_on_one_site] is applied without its [~_C_]
    argument. Each array element is therefore a partial application still
    waiting for [~_C_], so no per-site work is run. The array is then
    piped into [assert false]. *)
let lrt ~alignment ~tree =
  let phenotypes = phenotypes_of_tree tree in
  Array.init (Alignment.ncols alignment) ~f:(fun site ->
      lrt_on_one_site ~alignment ~phenotypes ~site)
  |> assert false