mdocekal committed
Commit 3fcf30e · 1 Parent(s): 7cf89e7

zero cardinality values

app.py CHANGED
File without changes
multi_label_precision_recall_accuracy_fscore.py CHANGED
@@ -69,6 +69,7 @@ Examples:
     "accuracy": 1.0,
     "fscore": 1.0
 }
+
 """
 
 
@@ -81,6 +82,8 @@ class MultiLabelPrecisionRecallAccuracyFscore(evaluate.Metric):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.beta = kwargs.get("beta", 1.0)
+        self.zero_cardinality_precision = kwargs.get("zero_cardinality_precision", 0.0)  # default precision when the prediction is empty; when both prediction and reference are empty, it is always 1
+        self.zero_cardinality_recall = kwargs.get("zero_cardinality_recall", 0.0)  # default recall when the reference is empty; when both prediction and reference are empty, it is always 1
         self.use_multiset = self.config_name == "multiset"
 
     def _info(self):
@@ -126,9 +129,9 @@ class MultiLabelPrecisionRecallAccuracyFscore(evaluate.Metric):
         prediction_cardinality = len(prediction)
         reference_cardinality = len(reference)
 
-        precision = intersection_cardinality / prediction_cardinality if prediction_cardinality > 0 else 0
-        recall = intersection_cardinality / reference_cardinality if reference_cardinality > 0 else 0
-        accuracy = intersection_cardinality / union_cardinality if union_cardinality > 0 else 0
+        precision = intersection_cardinality / prediction_cardinality if prediction_cardinality > 0 else self.zero_cardinality_precision
+        recall = intersection_cardinality / reference_cardinality if reference_cardinality > 0 else self.zero_cardinality_recall
+        accuracy = intersection_cardinality / union_cardinality  # no check needed: union_cardinality > 0 whenever prediction or reference is non-empty
 
         return precision, recall, accuracy
 
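The change above adds per-sample fallbacks for empty label sets: precision falls back to zero_cardinality_precision when the prediction is empty, and recall falls back to zero_cardinality_recall when the reference is empty. Below is a minimal standalone sketch of that per-sample computation for the default set-based configuration; the helper name per_sample_scores and the explicit both-empty branch (returning 1.0, which is what the new tests expect of the module) are illustrative and not part of the committed code.

# Minimal sketch mirroring the changed per-sample computation; names and the
# both-empty branch are illustrative assumptions, not the module itself.
def per_sample_scores(prediction, reference,
                      zero_cardinality_precision=0.0,
                      zero_cardinality_recall=0.0):
    prediction, reference = set(prediction), set(reference)
    intersection_cardinality = len(prediction & reference)
    union_cardinality = len(prediction | reference)

    if not prediction and not reference:
        # Both sides empty: scored as a perfect match (matches the new tests).
        return 1.0, 1.0, 1.0

    precision = (intersection_cardinality / len(prediction)
                 if prediction else zero_cardinality_precision)
    recall = (intersection_cardinality / len(reference)
              if reference else zero_cardinality_recall)
    accuracy = intersection_cardinality / union_cardinality  # union is non-empty here

    return precision, recall, accuracy


# An empty prediction against a non-empty reference falls back to the configured default:
print(per_sample_scores([], [0, 1], zero_cardinality_precision=0.5))  # (0.5, 0.0, 0.0)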
tests.py CHANGED
@@ -8,6 +8,7 @@ class MultiLabelPrecisionRecallAccuracyFscoreTest(TestCase):
     All of these tests are also used for multiset configuration. So please mind this and write the test in a way that
     it is valid for both configurations (do not use same label multiple times).
     """
+
     def setUp(self):
         self.multi_label_precision_recall_accuracy_fscore = MultiLabelPrecisionRecallAccuracyFscore()
 
@@ -149,7 +150,7 @@ class MultiLabelPrecisionRecallAccuracyFscoreTest(TestCase):
                 "precision": 1.0,
                 "recall": 0.5,
                 "accuracy": 0.5,
-                "fscore": 2/3
+                "fscore": 2 / 3
             },
             self.multi_label_precision_recall_accuracy_fscore.compute(
                 predictions=[
@@ -167,7 +168,7 @@ class MultiLabelPrecisionRecallAccuracyFscoreTest(TestCase):
                 "precision": 0.5,
                 "recall": 1.0,
                 "accuracy": 0.5,
-                "fscore": 2/3
+                "fscore": 2 / 3
             },
             self.multi_label_precision_recall_accuracy_fscore.compute(
                 predictions=[
@@ -184,7 +185,7 @@ class MultiLabelPrecisionRecallAccuracyFscoreTest(TestCase):
             {
                 "precision": 0.5,
                 "recall": 0.5,
-                "accuracy": 1/3,
+                "accuracy": 1 / 3,
                 "fscore": 0.5
             },
             self.multi_label_precision_recall_accuracy_fscore.compute(
@@ -200,10 +201,10 @@ class MultiLabelPrecisionRecallAccuracyFscoreTest(TestCase):
     def test_partial_match_multi_sample(self):
         self.assertDictEqual(
             {
-                "precision": 2.5/3,
-                "recall": 2/3,
+                "precision": 2.5 / 3,
+                "recall": 2 / 3,
                 "accuracy": 0.5,
-                "fscore": 2*(2.5/3 * 2/3) / (2.5/3 + 2/3)
+                "fscore": 2 * (2.5 / 3 * 2 / 3) / (2.5 / 3 + 2 / 3)
             },
             self.multi_label_precision_recall_accuracy_fscore.compute(
                 predictions=[
@@ -223,10 +224,10 @@ class MultiLabelPrecisionRecallAccuracyFscoreTest(TestCase):
         self.multi_label_precision_recall_accuracy_fscore.beta = 2
         self.assertDictEqual(
             {
-                "precision": 2.5/3,
-                "recall": 2/3,
+                "precision": 2.5 / 3,
+                "recall": 2 / 3,
                 "accuracy": 0.5,
-                "fscore": 5*(2.5/3 * 2/3) / (4*2.5/3 + 2/3)
+                "fscore": 5 * (2.5 / 3 * 2 / 3) / (4 * 2.5 / 3 + 2 / 3)
             },
             self.multi_label_precision_recall_accuracy_fscore.compute(
                 predictions=[
@@ -266,7 +267,8 @@ class MultiLabelPrecisionRecallAccuracyFscoreTest(TestCase):
 
 class MultiLabelPrecisionRecallAccuracyFscoreTestMultiset(MultiLabelPrecisionRecallAccuracyFscoreTest):
     def setUp(self):
-        self.multi_label_precision_recall_accuracy_fscore = MultiLabelPrecisionRecallAccuracyFscore(config_name="multiset")
+        self.multi_label_precision_recall_accuracy_fscore = MultiLabelPrecisionRecallAccuracyFscore(
+            config_name="multiset")
 
     def test_multiset_eok(self):
         self.assertDictEqual(
@@ -291,13 +293,12 @@ class MultiLabelPrecisionRecallAccuracyFscoreTestMultiset(MultiLabelPrecisionRec
         )
 
     def test_multiset_partial_match(self):
-
         self.assertDictEqual(
             {
                 "precision": 1.0,
                 "recall": 0.5,
                 "accuracy": 0.5,
-                "fscore": 2/3
+                "fscore": 2 / 3
             },
             self.multi_label_precision_recall_accuracy_fscore.compute(
                 predictions=[
@@ -310,15 +311,15 @@ class MultiLabelPrecisionRecallAccuracyFscoreTestMultiset(MultiLabelPrecisionRec
         )
 
     def test_multiset_partial_match_multi_sample(self):
-        p = (1+2/3) / 2
-        r = (3/4 + 1) / 2
+        p = (1 + 2 / 3) / 2
+        r = (3 / 4 + 1) / 2
 
         self.assertDictEqual(
             {
                 "precision": p,
                 "recall": r,
-                "accuracy": (3/4 + 2/3) / 2,
-                "fscore": 2*p*r / (p + r)
+                "accuracy": (3 / 4 + 2 / 3) / 2,
+                "fscore": 2 * p * r / (p + r)
             },
             self.multi_label_precision_recall_accuracy_fscore.compute(
                 predictions=[
@@ -331,3 +332,74 @@ class MultiLabelPrecisionRecallAccuracyFscoreTestMultiset(MultiLabelPrecisionRec
                 ]
             )
         )
+
+    def test_zero_cardinality_precision(self):
+        self.multi_label_precision_recall_accuracy_fscore.zero_cardinality_precision = 0.5
+        self.assertEqual(0.5,
+                         self.multi_label_precision_recall_accuracy_fscore.compute(
+                             predictions=[
+                                 []
+                             ],
+                             references=[
+                                 [0, 1, 1],
+                             ]
+                         )["precision"]
+                         )
+
+        self.assertEqual(1.0,
+                         self.multi_label_precision_recall_accuracy_fscore.compute(
+                             predictions=[
+                                 []
+                             ],
+                             references=[
+                                 [],
+                             ]
+                         )["precision"]
+                         )
+
+        self.assertEqual(2 / 3,
+                         self.multi_label_precision_recall_accuracy_fscore.compute(
+                             predictions=[
+                                 [1, 2, 3]
+                             ],
+                             references=[
+                                 [1, 2],
+                             ]
+                         )["precision"]
+                         )
+
+    def test_zero_cardinality_recall(self):
+        self.multi_label_precision_recall_accuracy_fscore.zero_cardinality_recall = 0.5
+        self.assertEqual(0.5,
+                         self.multi_label_precision_recall_accuracy_fscore.compute(
+                             predictions=[
+                                 [0, 1, 1],
+                             ],
+                             references=[
+                                 []
+                             ]
+                         )["recall"]
+                         )
+
+        self.assertEqual(1.0,
+                         self.multi_label_precision_recall_accuracy_fscore.compute(
+                             predictions=[
+                                 [],
+                             ],
+                             references=[
+                                 [],
+                             ]
+                         )["recall"]
+                         )
+
+        self.assertEqual(2 / 3,
+                         self.multi_label_precision_recall_accuracy_fscore.compute(
+                             predictions=[
+                                 [1, 2],
+                             ],
+                             references=[
+                                 [1, 2, 3]
+                             ]
+                         )["recall"]
+                         )
+
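The new tests exercise the defaults by setting the attributes directly on the metric instance. Because __init__ reads them via kwargs.get, they can also be supplied at construction time. Below is a hedged usage sketch under that assumption; it imports the module file directly from the working directory (the Hub repo id for evaluate.load is not shown in this commit), and the printed result keys come from the tests above.

from multi_label_precision_recall_accuracy_fscore import MultiLabelPrecisionRecallAccuracyFscore

# Sketch: treat empty predictions/references as fully correct instead of scoring them 0.
metric = MultiLabelPrecisionRecallAccuracyFscore(
    zero_cardinality_precision=1.0,
    zero_cardinality_recall=1.0,
)

result = metric.compute(
    predictions=[[1, 2], []],
    references=[[1, 2, 3], []],
)
print(result)  # dict with "precision", "recall", "accuracy" and "fscore"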