Commit 70f7f82d authored by Philippe Veber's avatar Philippe Veber
Browse files

prc: added tests against sklearn

parent 61dcd710
(tests
(names prc_test)
(libraries alcotest prc)
(libraries alcotest prc sklearn)
(modes native)
(preprocess (pps ppx_jane ppx_deriving.show)))
......@@ -22,6 +22,13 @@ module Naive_implementation = struct
/. float n
end
module Sklearn_implementation = struct
let average_precision (Dataset xs) =
let y_score = List.map xs ~f:fst |> Array.of_list |> Np.Numpy.vectorf in
let y_true = List.map xs ~f:(fun (_, b) -> Bool.to_int b) |> Array.of_list |> Np.Numpy.vectori in
Sklearn.Metrics.average_precision_score ~y_score ~y_true ()
end
let rec binormal_generator ~n rng =
let alpha = Gsl.Randist.flat rng ~a:0.1 ~b:1. in
let model = Binormal_model.make alpha in
......@@ -56,16 +63,26 @@ let average_precision_test gen =
Precision_recall.auc_average_precision
Naive_implementation.average_precision
let tests = [
"average_precision on binormal model 3", `Quick, average_precision_test (binormal_generator ~n:3) ;
"average_precision on binormal model 10", `Quick, average_precision_test (binormal_generator ~n:10) ;
"average_precision on binormal model 100", `Quick, average_precision_test (binormal_generator ~n:100) ;
"average_precision on saturated binormal model 3", `Quick, average_precision_test (saturated_binormal_generator ~n:3) ;
"average_precision on saturated binormal model 10", `Quick, average_precision_test (saturated_binormal_generator ~n:10) ;
"average_precision on saturated binormal model 100", `Quick, average_precision_test (saturated_binormal_generator ~n:100) ;
let average_precision_sklearn_test gen =
test_implementations
~gen ~show:show_dataset ~n:100 ~equal:floats_are_close_enough
Precision_recall.auc_average_precision
Naive_implementation.average_precision
let average_precision_test_suite f = [
"average_precision on binormal model 3", `Quick, f (binormal_generator ~n:3) ;
"average_precision on binormal model 10", `Quick, f (binormal_generator ~n:10) ;
"average_precision on binormal model 100", `Quick, f (binormal_generator ~n:100) ;
"average_precision on saturated binormal model 3", `Quick, f (saturated_binormal_generator ~n:3) ;
"average_precision on saturated binormal model 10", `Quick, f (saturated_binormal_generator ~n:10) ;
"average_precision on saturated binormal model 100", `Quick, f (saturated_binormal_generator ~n:100) ;
]
let naive_implementation_suite = average_precision_test_suite average_precision_test
let sklearn_suite = average_precision_test_suite average_precision_sklearn_test
let () =
Alcotest.run "prc" [
"test-against-naive-implementation", tests
"test-against-naive-implementation", naive_implementation_suite ;
"test-against-sklearn", sklearn_suite ;
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment