Upload postprocess_predictions.py (#4)
Browse files- Upload postprocess_predictions.py (ebdcf7085ca211fbcbc7ab70c517956942e9ccc2)
Co-authored-by: Alexander Gebhard <[email protected]>
- postprocess_predictions.py +36 -0
postprocess_predictions.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import pandas as pd
|
5 |
+
import numpy as np
|
6 |
+
import argparse
|
7 |
+
from glob import glob
|
8 |
+
|
9 |
+
|
10 |
+
if __name__=='__main__':
|
11 |
+
parser = argparse.ArgumentParser()
|
12 |
+
parser.add_argument(
|
13 |
+
"--path",
|
14 |
+
type=str,
|
15 |
+
required=True
|
16 |
+
)
|
17 |
+
args = parser.parse_args()
|
18 |
+
|
19 |
+
predictions_path = args.path
|
20 |
+
threshold = .4
|
21 |
+
|
22 |
+
# Load and prepare predictions for new thresholds
|
23 |
+
pred = pd.read_csv(predictions_path).set_index("filename")
|
24 |
+
pred = pred.groupby("filename").apply(
|
25 |
+
lambda x: pd.Series({
|
26 |
+
"Anth": (x["Anth"] > .5).sum() >= 5,
|
27 |
+
"Bio": (x["Bio"] > .5).sum() >= 2,
|
28 |
+
"Geo": (x["Geo"] > .5).sum() >= 3,
|
29 |
+
"Sil": (x["Sil"] > threshold).sum() >= 55,
|
30 |
+
})
|
31 |
+
)
|
32 |
+
pred["Sil"] = pred.apply(lambda x: ~x["Anth"] & ~x["Bio"] & ~x["Geo"], axis=1)
|
33 |
+
pred = pred.astype(int)
|
34 |
+
pred.reset_index(inplace=True)
|
35 |
+
print(pred.tail())
|
36 |
+
pred.to_csv(predictions_path.split(".csv")[0] + "_postprocessed.csv", index=False)
|