Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Open sidebar
VEBER Philippe
codepi
Commits
dd0615ff
Commit
dd0615ff
authored
Nov 30, 2020
by
Philippe Veber
Browse files
prc: improved average_precision implementation and fixed tests
parent
0857d270
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
9 deletions
+24
-9
lib/prc/prc.ml
lib/prc/prc.ml
+12
-4
tests/prc_test.ml
tests/prc_test.ml
+12
-5
No files found.
lib/prc/prc.ml
View file @
dd0615ff
...
@@ -116,15 +116,23 @@ module Precision_recall = struct
...
@@ -116,15 +116,23 @@ module Precision_recall = struct
let
npos'
,
nneg'
=
if
b
then
1
,
0
else
0
,
1
in
let
npos'
,
nneg'
=
if
b
then
1
,
0
else
0
,
1
in
loop
((
s
,
npos
,
nneg
)
::
closed_groups
)
(
Some
(
s'
,
npos'
,
nneg'
))
t
loop
((
s
,
npos
,
nneg
)
::
closed_groups
)
(
Some
(
s'
,
npos'
,
nneg'
))
t
in
in
let
compare
=
Tuple
.
T2
.
compare
~
cmp1
:
Float
.
de
scending
~
cmp2
:
Bool
.
de
scending
in
let
compare
=
Tuple
.
T2
.
compare
~
cmp1
:
Float
.
a
scending
~
cmp2
:
Bool
.
a
scending
in
List
.
sort
xs
~
compare
List
.
sort
xs
~
compare
|>
loop
[]
None
|>
loop
[]
None
let
decreasing_ties_groups_test
d
=
let
groups
=
decreasing_ties_groups
d
in
Sexplib
.
Sexp
.
pp
Format
.
std_formatter
([
%
sexp_of
:
(
float
*
int
*
int
)
list
]
groups
)
let
%
expect_test
"decreasing_ties_groups"
=
let
%
expect_test
"decreasing_ties_groups"
=
let
data
=
Dataset
[
0
.,
false
;
1
.,
false
;
2
.,
true
;
0
.,
false
;
1
.,
false
;
2
.,
true
;
0
.,
true
;
1
.,
false
;
2
.,
true
;
0
.,
true
]
in
let
data
=
Dataset
[
0
.,
false
;
1
.,
false
;
2
.,
true
;
0
.,
false
;
1
.,
false
;
2
.,
true
;
0
.,
true
;
1
.,
false
;
2
.,
true
;
0
.,
true
]
in
let
groups
=
decreasing_ties_groups
data
in
decreasing_ties_groups_test
data
;
Sexplib
.
Sexp
.
pp
Format
.
std_formatter
([
%
sexp_of
:
(
float
*
int
*
int
)
list
]
groups
)
;
[
%
expect
"((2 3 0)(1 0 3)(0 2 2))"
]
[
%
expect
"((0 2 2)(1 0 3)(2 3 0))"
]
let
%
expect_test
"decreasing_ties_groups 2"
=
let
data
=
Dataset
[(
1
.
0070885317
,
true
);
(
0
.
297475057516
,
false
);
(
0
.
831050790341
,
false
)]
in
decreasing_ties_groups_test
data
;
[
%
expect
"((1.0070885317 1 0)(0.831050790341 0 1)(0.297475057516 0 1))"
]
let
auc_average_precision
(
Dataset
xs
as
d
)
=
let
auc_average_precision
(
Dataset
xs
as
d
)
=
if
List
.
is_empty
xs
then
invalid_arg
"average_precision estimator undefined on empty lists"
;
if
List
.
is_empty
xs
then
invalid_arg
"average_precision estimator undefined on empty lists"
;
...
...
tests/prc_test.ml
View file @
dd0615ff
...
@@ -16,16 +16,18 @@ module Naive_implementation = struct
...
@@ -16,16 +16,18 @@ module Naive_implementation = struct
float
tp
/.
float
(
tp
+
fp
d
c
)
float
tp
/.
float
(
tp
+
fp
d
c
)
let
average_precision
(
Dataset
xs
as
d
)
=
let
average_precision
(
Dataset
xs
as
d
)
=
if
List
.
is_empty
xs
then
invalid_arg
"average_precision estimator undefined on empty lists"
;
if
List
.
is_empty
xs
then
invalid_arg
"average_precision estimator undefined on empty lists"
;
let
ys
=
List
.
map
xs
~
f
:
fst
in
let
ys
=
List
.
filter_
map
xs
~
f
:
(
fun
(
x
,
b
)
->
if
b
then
Some
x
else
None
)
in
let
n
=
List
.
length
x
s
in
let
n
=
List
.
length
y
s
in
List
.
fold
ys
~
init
:
0
.
~
f
:
(
fun
acc
y
->
precision
d
y
+.
acc
)
List
.
fold
ys
~
init
:
0
.
~
f
:
(
fun
acc
y
->
precision
d
y
+.
acc
)
/.
float
n
/.
float
n
end
end
let
binormal_generator
~
n
rng
=
let
rec
binormal_generator
~
n
rng
=
let
alpha
=
Gsl
.
Randist
.
flat
rng
~
a
:
0
.
~
b
:
1
.
in
let
alpha
=
Gsl
.
Randist
.
flat
rng
~
a
:
0
.
1
~
b
:
1
.
in
let
model
=
Binormal_model
.
make
alpha
in
let
model
=
Binormal_model
.
make
alpha
in
Binormal_model
.
simulation
rng
~
n
model
let
Dataset
xs
as
d
=
Binormal_model
.
simulation
rng
~
n
model
in
if
List
.
exists
xs
~
f
:
snd
then
d
else
binormal_generator
~
n
rng
let
saturated_binormal_generator
~
n
rng
=
let
saturated_binormal_generator
~
n
rng
=
let
clip
a
b
x
=
Float
.(
if
x
<
a
then
a
else
if
x
>
b
then
b
else
x
)
in
let
clip
a
b
x
=
Float
.(
if
x
<
a
then
a
else
if
x
>
b
then
b
else
x
)
in
...
@@ -56,6 +58,11 @@ let average_precision_test gen =
...
@@ -56,6 +58,11 @@ let average_precision_test gen =
let
tests
=
[
let
tests
=
[
"average_precision on binormal model 3"
,
`Quick
,
average_precision_test
(
binormal_generator
~
n
:
3
)
;
"average_precision on binormal model 3"
,
`Quick
,
average_precision_test
(
binormal_generator
~
n
:
3
)
;
"average_precision on binormal model 10"
,
`Quick
,
average_precision_test
(
binormal_generator
~
n
:
10
)
;
"average_precision on binormal model 100"
,
`Quick
,
average_precision_test
(
binormal_generator
~
n
:
100
)
;
"average_precision on saturated binormal model 3"
,
`Quick
,
average_precision_test
(
saturated_binormal_generator
~
n
:
3
)
;
"average_precision on saturated binormal model 10"
,
`Quick
,
average_precision_test
(
saturated_binormal_generator
~
n
:
10
)
;
"average_precision on saturated binormal model 100"
,
`Quick
,
average_precision_test
(
saturated_binormal_generator
~
n
:
100
)
;
]
]
let
()
=
let
()
=
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment