convergence_tree.ml 7.42 KB
Newer Older
1 2 3
open Core_kernel
open Phylogenetics

4 5 6 7 8 9 10
(** Condition attached to a branch of the tree. [Tags.string_of_condition]
    encodes [`Ancestral] as ["0"] and [`Convergent] as ["1"]. *)
type condition = [`Ancestral | `Convergent]

(** Data carried by each branch of a convergence tree. *)
type branch_info = {
  condition : condition ; (* condition assigned to the branch *)
  length : float ; (* length of the branch *)
}

11 12 13 14 15 16
(** Accessors over {!branch_info}, packaged as a module with a type [t]. *)
module Branch_info = struct
  type t = branch_info

  (** [length bi] is the length recorded in [bi]. *)
  let length ({ length ; _ } : t) = length

  (** [condition bi] is the condition recorded in [bi]. *)
  let condition ({ condition ; _ } : t) = condition
end

17
(** A convergence tree: inner nodes carry [unit], leaves carry a [string]
    (the leaf name, see [of_newick_tree]) and branches carry a
    [branch_info]. *)
type t = (unit, string, branch_info) Tree.t
18 19 20

(** Helpers to read and update branch tags, represented as association lists
    from string keys to string values. *)
module Tags = struct
  (** Key under which the condition of a branch is stored. *)
  let condition_label = "Condition"

  (** Key under which a change of condition is stored. *)
  let transition_label = "Transition"

  (** [condition tags] is the value bound to {!condition_label} in [tags],
      if any. *)
  let condition tags =
    List.Assoc.find tags ~equal:String.equal condition_label

  (** String encoding of a condition: ["0"] for [`Ancestral], ["1"] for
      [`Convergent]. *)
  let string_of_condition c =
    match c with
    | `Ancestral -> "0"
    | `Convergent -> "1"

  (** [set_condition tags c] binds {!condition_label} to [c], dropping any
      previous binding of that key. *)
  let set_condition tags c =
    let without_condition =
      List.Assoc.remove tags ~equal:String.equal condition_label
    in
    List.Assoc.add without_condition ~equal:String.equal condition_label c

  (* let other_tags tags =
   *   List.filter tags ~f:(fun (key, _) -> String.(key <> condition_label && key <> transition_label)) *)

  (** [unset_transition tags] is [tags] without any {!transition_label}
      binding. *)
  let unset_transition tags =
    List.Assoc.remove tags ~equal:String.equal transition_label

  (** [set_transition tags c] binds {!transition_label} to [c], dropping any
      previous binding of that key. *)
  let set_transition tags c =
    let without_transition = unset_transition tags in
    List.Assoc.add without_transition ~equal:String.equal transition_label c
end

(** [condition_of_branch_info bi] is the raw (string) condition tag of the
    Newick branch [bi], or [None] when the branch has no such tag. *)
let condition_of_branch_info { Newick.tags ; _ } =
  Tags.condition tags

51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
(** [of_newick_tree t] converts a Newick tree into a convergence tree.

    Inner node data is discarded, each leaf is replaced by its name and each
    branch by its length and decoded condition. The condition is read through
    {!Tags.condition}, so the tag key is defined in a single place.

    Returns [Error (`Msg _)] when a leaf has no name, a branch has no length,
    or a branch's condition tag is missing or is neither ["0"] nor ["1"]. *)
let of_newick_tree t =
  let node _ = () in
  let leaf (l : Newick.node_info) =
    match l.name with
    | Some n -> n
    | None -> failwith "missing leaf name"
  in
  let branch (b : Newick.branch_info) =
    (* The length is checked before the condition so that the reported error
       is the same as before when both are missing. *)
    let length =
      match b.length with
      | Some bl -> bl
      | None -> failwith "missing branch length"
    in
    let condition =
      match Tags.condition b.tags with
      | Some "0" -> `Ancestral
      | Some "1" -> `Convergent
      | Some s -> failwithf "Invalid condition: %s" s ()
      | None -> failwithf "Missing %s tag" Tags.condition_label ()
    in
    { length ; condition }
  in
  (* Failures raised by the callbacks above are turned into an error value. *)
  match Tree.map t ~node ~leaf ~branch with
  | tree -> Ok tree
  | exception Failure msg -> Error (`Msg msg)
80

81 82 83 84 85 86 87 88 89 90
(** [to_newick_tree t] converts a convergence tree into a Newick tree: inner
    nodes are unnamed, leaves are named after their label and branches keep
    their length.

    NOTE(review): branches are emitted with an empty tag list, so the
    condition is not exported; confirm this is intended. *)
let to_newick_tree t =
  Tree.map t
    ~node:(fun () -> { Newick.name = None })
    ~leaf:(fun label -> { Newick.name = Some label })
    ~branch:(fun ({ length ; _ } : branch_info) ->
        { Newick.length = Some length ; tags = [] })

91 92
(** [from_file fn] reads file [fn] with [Newick.from_file] and converts its
    inner tree with {!of_newick_tree}. *)
let from_file fn =
  let parsed = Newick.from_file fn in
  Newick.with_inner_tree parsed ~f:of_newick_tree
94 95 96 97 98

(** [leaves tree] lists the leaves of [tree], each paired with the condition
    of the branch leading to it. A tree reduced to a single leaf yields that
    leaf with [`Ancestral]. *)
let leaves tree =
  (* [incoming] is the condition of the branch above the current subtree;
     leaves are consed onto [acc] while folding branches from the right. *)
  let rec collect ~incoming subtree acc =
    match subtree with
    | Tree.Leaf species -> (species, incoming) :: acc
    | Tree.Node { branches ; _ } ->
      List1.fold_right branches ~init:acc
        ~f:(fun (Tree.Branch { data ; tip }) acc ->
            collect ~incoming:data.condition tip acc)
  in
  collect ~incoming:`Ancestral tree []

105 106 107 108 109 110 111 112
(* Checks that the first four leaves of the reference tree are returned in
   the expected order. *)
let%test "leaves computed in correct order" =
  let expected = [ "Chrysithr" ; "Ele.bald" ; "Ele.bal2" ; "Ele.bal4" ] in
  match from_file "../../data/besnard2009/besnard2009.nhx" with
  | Error (`Msg msg) -> failwith msg
  | Ok t ->
    let first_four = List.take (leaves t) 4 |> List.map ~f:fst in
    List.equal String.equal first_four expected

113
(** [transfer_condition_to_branches t] takes a tree whose nodes and leaves
    are pairs [(data, condition)] and moves each condition onto the branch
    leading to that node or leaf, as a condition tag. Nodes and leaves keep
    only the first component of their pair. *)
let rec transfer_condition_to_branches t =
  (* Condition stored at the root of a subtree. *)
  let condition_of_tip tip : condition =
    match tip with
    | Tree.Leaf (_, c) -> c
    | Tree.Node { data = (_, c) ; _ } -> c
  in
  let retag (Tree.Branch { data ; tip }) =
    let label = Tags.string_of_condition (condition_of_tip tip) in
    let tags = Tags.set_condition data.Newick.tags label in
    Tree.branch
      { data with Newick.tags }
      (transfer_condition_to_branches tip)
  in
  match t with
  | Tree.Leaf (name, _) -> Tree.leaf name
  | Tree.Node { data = (info, _) ; branches } ->
    Tree.node info (List1.map branches ~f:retag)
130

131
(** [reset_transitions tree] recomputes transition tags from condition tags.

    For every branch below the root's own branches, the transition tag is set
    to the branch condition when it differs from the condition of the branch
    above, and removed otherwise. Branches leaving the root keep their tags
    unchanged.

    @raise Failure if a branch has no condition tag. *)
let reset_transitions (tree : Newick.tree) =
  let condition_exn (bi : Newick.branch_info) =
    match Tags.condition bi.tags with
    | Some c -> c
    | None -> failwith "tree tagged with condition expected"
  in
  (* [parent_condition] is [None] only for the branches leaving the root. *)
  let rec visit parent_condition (subtree : Newick.tree) : Newick.tree =
    match subtree with
    | Tree.Leaf _ -> subtree
    | Tree.Node n ->
      let retag (Tree.Branch b) =
        let data : Newick.branch_info = b.data in
        let own_condition = condition_exn data in
        let tags =
          match parent_condition with
          | None -> data.tags
          | Some parent ->
            if String.equal own_condition parent then
              Tags.unset_transition data.tags
            else Tags.set_transition data.tags own_condition
        in
        Tree.branch { data with tags } (visit (Some own_condition) b.tip)
      in
      Tree.Node { n with branches = List1.map n.branches ~f:retag }
  in
  visit None tree
164 165 166 167 168

(** [length_on_each_condition branches] sums branch lengths per condition tag
    and returns the result as an association list. Branches lacking either a
    condition tag or a length are ignored. *)
let length_on_each_condition branches =
  let module Accu = Biocaml_unix.Accu in
  let totals = Accu.create ~bin:Fn.id ~zero:0. ~add:( +. ) () in
  let record (bi : Newick.branch_info) =
    match Tags.condition bi.tags, bi.length with
    | Some cond, Some len -> Accu.add totals cond len
    | None, _ | _, None -> ()
  in
  List.iter branches ~f:record ;
  Accu.to_alist totals

174
(** [remove_nodes_with_single_child tree] simplifies nodes having a single
    child via [Tree.simplify_node_with_single_child], then recomputes
    transition tags with {!reset_transitions}.

    Merged branches get the sum of the merged lengths and, as sole tag, the
    condition covering the greatest total length (no tag when no condition
    could be found).

    @raise an exception (from [Option.value_exn]) if a merged branch has no
    length. *)
let remove_nodes_with_single_child (tree : Newick.tree) =
  let merge_branch_data (branches : Newick.branch_info list) =
    let tags =
      length_on_each_condition branches
      |> List.max_elt ~compare:(fun (_, l1) (_, l2) -> Float.compare l1 l2)
      |> Option.value_map ~default:[] ~f:(fun (cond, _) ->
          Tags.set_condition [] cond)
    in
    let total_length =
      List.fold branches ~init:0. ~f:(fun sum (bi : Newick.branch_info) ->
          sum +. Option.value_exn bi.length)
    in
    { Newick.tags ; length = Some total_length }
  in
  Tree.simplify_node_with_single_child tree ~merge_branch_data
  |> reset_transitions

194 195
(** [infer_binary_condition_on_branches ?gain_relative_cost t
    ~convergent_leaves] assigns a condition to every branch of [t].

    Named leaves get state [1] when their name belongs to
    [convergent_leaves] and [0] otherwise; unnamed ones get no state. States
    are then propagated with [Fitch.fitch], where a [0 -> 1] change costs
    [gain_relative_cost] (default [2.]) and a [1 -> 0] change costs [1.].
    The resulting states are written on branches as condition tags and
    transition tags are recomputed. *)
let infer_binary_condition_on_branches ?(gain_relative_cost = 2.) t
    ~convergent_leaves =
  let category { Newick.name ; _ } =
    Option.map name ~f:(fun leaf ->
        match String.Set.mem convergent_leaves leaf with
        | true -> 1
        | false -> 0)
  in
  let cost from_state to_state =
    match from_state, to_state with
    | 0, 0 | 1, 1 -> 0.
    | 0, 1 -> gain_relative_cost
    | 1, 0 -> 1.
    | _ -> assert false (* only two states are requested from Fitch *)
  in
  let decode_state = function
    | 0 -> `Ancestral
    | 1 -> `Convergent
    | _ -> assert false
  in
  let annotate (payload, state) = payload, decode_state state in
  let with_states = Fitch.fitch ~cost ~n:2 ~category t in
  let with_conditions =
    Tree.map with_states ~node:annotate ~leaf:annotate ~branch:Fn.id
  in
  reset_transitions (transfer_condition_to_branches with_conditions)
215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239


(** [alignment_counts_map tree alignment f] applies [f] to every column of
    [alignment] and returns the results in column order.

    For each column, [f] receives two amino-acid count tables: the first
    built from the sequences of leaves reached through an [`Ancestral]
    branch, the second from those reached through a [`Convergent] branch.

    @raise Failure if a leaf of [tree] has no sequence in [alignment]. *)
let alignment_counts_map tree alignment f =
  let sequence_exn name =
    match Alignment.find_sequence alignment name with
    | Some seq -> seq
    | None -> failwithf "Could not find %s in alignment" name ()
  in
  let ancestral_seqs, convergent_seqs =
    leaves tree
    |> List.map ~f:(fun (name, cond) -> sequence_exn name, cond)
    |> List.partition_map ~f:(function
        | seq, `Ancestral -> Either.First seq
        | seq, `Convergent -> Either.Second seq)
  in
  (* Number of sequences of [seqs] showing each amino acid at column [col]. *)
  let column_counts seqs col =
    Amino_acid.Table.init (fun aa ->
        let residue = Amino_acid.to_char aa in
        List.count seqs ~f:(fun seq -> Char.equal seq.[col] residue))
  in
  List.init (Alignment.ncols alignment) ~f:(fun col ->
      f (column_counts ancestral_seqs col) (column_counts convergent_seqs col))
240 241 242 243 244 245 246 247 248 249 250 251

(** [pair_tree ~node_info ~leaf_info ~branch_length1 ~branch_length2 ~npairs]
    builds a tree whose root has [npairs] children. Each child is a binary
    node reached through an [`Ancestral] branch of length [branch_length1],
    holding leaf [2 * i] under an [`Ancestral] branch and leaf [2 * i + 1]
    under a [`Convergent] branch, both of length [branch_length2].

    Every inner node carries [node_info]; leaf data is [leaf_info index
    condition]. *)
let pair_tree ~node_info ~leaf_info ~branch_length1 ~branch_length2 ~npairs =
  let make_leaf index cond = Tree.Leaf (leaf_info index cond) in
  let ancestral_branch length tip =
    Tree.branch { length ; condition = `Ancestral } tip
  in
  let convergent_branch length tip =
    Tree.branch { length ; condition = `Convergent } tip
  in
  let pair_branch i =
    let cherry =
      Tree.binary_node node_info
        (ancestral_branch branch_length2 (make_leaf (2 * i) `Ancestral))
        (convergent_branch branch_length2 (make_leaf (2 * i + 1) `Convergent))
    in
    ancestral_branch branch_length1 cherry
  in
  Tree.node node_info (List1.init npairs ~f:pair_branch)