Commit dd0615ff authored by Philippe Veber's avatar Philippe Veber
Browse files

prc: improved average_precision implementation and fixed tests

parent 0857d270
......@@ -116,15 +116,23 @@ module Precision_recall = struct
let npos', nneg' = if b then 1, 0 else 0, 1 in
loop ((s, npos, nneg) :: closed_groups) (Some (s', npos', nneg')) t
in
let compare = Tuple.T2.compare ~cmp1:Float.descending ~cmp2:Bool.descending in
let compare = Tuple.T2.compare ~cmp1:Float.ascending ~cmp2:Bool.ascending in
List.sort xs ~compare
|> loop [] None
let decreasing_ties_groups_test d =
let groups = decreasing_ties_groups d in
Sexplib.Sexp.pp Format.std_formatter ([%sexp_of: (float * int * int) list] groups)
let%expect_test "decreasing_ties_groups" =
let data = Dataset [0., false; 1., false; 2., true; 0., false; 1., false; 2., true; 0., true; 1., false; 2., true; 0., true] in
let groups = decreasing_ties_groups data in
Sexplib.Sexp.pp Format.std_formatter ([%sexp_of: (float * int * int) list] groups) ;
[%expect "((0 2 2)(1 0 3)(2 3 0))"]
decreasing_ties_groups_test data ;
[%expect "((2 3 0)(1 0 3)(0 2 2))"]
let%expect_test "decreasing_ties_groups 2" =
let data = Dataset [(1.0070885317, true); (0.297475057516, false); (0.831050790341, false)] in
decreasing_ties_groups_test data ;
[%expect "((1.0070885317 1 0)(0.831050790341 0 1)(0.297475057516 0 1))"]
let auc_average_precision (Dataset xs as d) =
if List.is_empty xs then invalid_arg "average_precision estimator undefined on empty lists" ;
......
......@@ -16,16 +16,18 @@ module Naive_implementation = struct
float tp /. float (tp + fp d c)
let average_precision (Dataset xs as d) =
if List.is_empty xs then invalid_arg "average_precision estimator undefined on empty lists" ;
let ys = List.map xs ~f:fst in
let n = List.length xs in
let ys = List.filter_map xs ~f:(fun (x, b) -> if b then Some x else None) in
let n = List.length ys in
List.fold ys ~init:0. ~f:(fun acc y -> precision d y +. acc)
/. float n
end
let binormal_generator ~n rng =
let alpha = Gsl.Randist.flat rng ~a:0. ~b:1. in
let rec binormal_generator ~n rng =
let alpha = Gsl.Randist.flat rng ~a:0.1 ~b:1. in
let model = Binormal_model.make alpha in
Binormal_model.simulation rng ~n model
let Dataset xs as d = Binormal_model.simulation rng ~n model in
if List.exists xs ~f:snd then d
else binormal_generator ~n rng
let saturated_binormal_generator ~n rng =
let clip a b x = Float.(if x < a then a else if x > b then b else x) in
......@@ -56,6 +58,11 @@ let average_precision_test gen =
let tests = [
"average_precision on binormal model 3", `Quick, average_precision_test (binormal_generator ~n:3) ;
"average_precision on binormal model 10", `Quick, average_precision_test (binormal_generator ~n:10) ;
"average_precision on binormal model 100", `Quick, average_precision_test (binormal_generator ~n:100) ;
"average_precision on saturated binormal model 3", `Quick, average_precision_test (saturated_binormal_generator ~n:3) ;
"average_precision on saturated binormal model 10", `Quick, average_precision_test (saturated_binormal_generator ~n:10) ;
"average_precision on saturated binormal model 100", `Quick, average_precision_test (saturated_binormal_generator ~n:100) ;
]
let () =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment