Commit 0857d270 authored by Philippe Veber's avatar Philippe Veber
Browse files

prc: new (and wrong) implementation of average_position + tests

parent a7f58c54
......@@ -126,21 +126,18 @@ module Precision_recall = struct
Sexplib.Sexp.pp Format.std_formatter ([%sexp_of: (float * int * int) list] groups) ;
[%expect "((0 2 2)(1 0 3)(2 3 0))"]
let auc_average_precision (Dataset xs) =
let pos, _, sum =
List.sort xs ~compare:(Tuple.T2.compare ~cmp1:Float.descending ~cmp2:Bool.descending)
|> List.fold ~init:(0, 0, 0.) ~f:(fun (pos, neg, sum) (_, b) ->
let pos, neg = if b then pos + 1, neg else pos, neg +1 in
let sum =
if b then
let precision = float pos /. float (pos + neg) in
precision +. sum
else sum
in
(pos, neg, sum)
let auc_average_precision (Dataset xs as d) =
if List.is_empty xs then invalid_arg "average_precision estimator undefined on empty lists" ;
let npos, _, sum =
decreasing_ties_groups d
|> List.fold ~init:(0, 0, 0.) ~f:(fun (npos, nneg, sum) (_, p, n) ->
let npos = npos + p and nneg = nneg + n in
let prec = float npos /. float (npos + nneg) in
let sum = float p *. prec +. sum in
npos, nneg, sum
)
in
sum /. float pos
sum /. float npos
let logit_confidence_interval ~alpha ~theta_hat ~n =
let eta_hat = logit theta_hat in
......
(tests
(names prc_test)
(libraries alcotest prc)
(preprocess (pps ppx_jane ppx_deriving.show)))
open Core_kernel
open Prc
let floats_are_close_enough x y = Float.(abs (x -. y) < 1e-6)
module Naive_implementation = struct
let tp (Dataset xs) c = List.count xs ~f:(fun (x, b) -> b && Float.(x >= c))
let _tn (Dataset xs) c = List.count xs ~f:(fun (x, b) -> not b && Float.(x < c))
let fp (Dataset xs) c = List.count xs ~f:(fun (x, b) -> not b && Float.(x >= c))
let fn (Dataset xs) c = List.count xs ~f:(fun (x, b) -> b && Float.(x < c))
let _recall d c =
let tp = tp d c in
float tp /. float (tp + fn d c)
let precision d c =
let tp = tp d c in
float tp /. float (tp + fp d c)
let average_precision (Dataset xs as d) =
if List.is_empty xs then invalid_arg "average_precision estimator undefined on empty lists" ;
let ys = List.map xs ~f:fst in
let n = List.length xs in
List.fold ys ~init:0. ~f:(fun acc y -> precision d y +. acc)
/. float n
end
let binormal_generator ~n rng =
let alpha = Gsl.Randist.flat rng ~a:0. ~b:1. in
let model = Binormal_model.make alpha in
Binormal_model.simulation rng ~n model
let saturated_binormal_generator ~n rng =
let clip a b x = Float.(if x < a then a else if x > b then b else x) in
let Dataset xs = binormal_generator ~n rng in
Dataset (List.map xs ~f:(fun (x, b) -> clip (-0.5) 1.5 x, b))
let test_implementations ~gen ~n ~show ~equal f g () =
let rng = Gsl.Rng.(make (default ())) in
let rec loop i =
if i < n then
let x = gen rng in
if not (equal (f x) (g x)) then (
sprintf "Implementations differ on %s" (show x)
|> Alcotest.fail
)
else loop (i + 1)
in
loop 0
type dataset = Prc.dataset = Dataset of (float * bool) list
[@@deriving show]
let average_precision_test gen =
test_implementations
~gen ~show:show_dataset ~n:100 ~equal:floats_are_close_enough
Precision_recall.auc_average_precision
Naive_implementation.average_precision
let tests = [
"average_precision on binormal model 3", `Quick, average_precision_test (binormal_generator ~n:3) ;
]
let () =
Alcotest.run "prc" [
"test-against-naive-implementation", tests
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment