trianand AlexanderGbd commited on
Commit
f845a81
·
verified ·
1 Parent(s): 6c9e90e

Upload postprocess_predictions.py (#4)

Browse files

- Upload postprocess_predictions.py (ebdcf7085ca211fbcbc7ab70c517956942e9ccc2)


Co-authored-by: Alexander Gebhard <[email protected]>

Files changed (1) hide show
  1. postprocess_predictions.py +36 -0
postprocess_predictions.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import sys
4
+ import pandas as pd
5
+ import numpy as np
6
+ import argparse
7
+ from glob import glob
8
+
9
+
10
+ if __name__=='__main__':
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument(
13
+ "--path",
14
+ type=str,
15
+ required=True
16
+ )
17
+ args = parser.parse_args()
18
+
19
+ predictions_path = args.path
20
+ threshold = .4
21
+
22
+ # Load and prepare predictions for new thresholds
23
+ pred = pd.read_csv(predictions_path).set_index("filename")
24
+ pred = pred.groupby("filename").apply(
25
+ lambda x: pd.Series({
26
+ "Anth": (x["Anth"] > .5).sum() >= 5,
27
+ "Bio": (x["Bio"] > .5).sum() >= 2,
28
+ "Geo": (x["Geo"] > .5).sum() >= 3,
29
+ "Sil": (x["Sil"] > threshold).sum() >= 55,
30
+ })
31
+ )
32
+ pred["Sil"] = pred.apply(lambda x: ~x["Anth"] & ~x["Bio"] & ~x["Geo"], axis=1)
33
+ pred = pred.astype(int)
34
+ pred.reset_index(inplace=True)
35
+ print(pred.tail())
36
+ pred.to_csv(predictions_path.split(".csv")[0] + "_postprocessed.csv", index=False)