# edansa-2019-cnn10-32k-t / postprocess_predictions.py
# Author: trianand
# Commit f845a81 (verified): "Upload postprocess_predictions.py (#4)"
import os
import sys
import pandas as pd
import numpy as np
import argparse
from glob import glob
# Per-class decision rule: a class is marked present for a filename when at
# least MIN_POSITIVE_ROWS[label] of that file's prediction rows have a score
# strictly greater than SCORE_THRESHOLD.
SCORE_THRESHOLD = .5
MIN_POSITIVE_ROWS = {"Anth": 5, "Bio": 2, "Geo": 3}


def _aggregate_per_file(pred, score_threshold=SCORE_THRESHOLD, min_counts=None):
    """Collapse row-level class scores into one 0/1 label row per filename.

    Args:
        pred: DataFrame with a "filename" column plus one numeric score column
            per key of ``min_counts`` (by default "Anth", "Bio", "Geo").
        score_threshold: a row counts as positive for a class when its score
            is strictly greater than this value (NaN scores never count).
        min_counts: mapping of class column -> minimum number of positive rows
            a file needs for that class to be marked present. Defaults to
            ``MIN_POSITIVE_ROWS``.

    Returns:
        DataFrame with columns "filename", the ``min_counts`` classes, then
        "Sil"; one row per filename, sorted by filename, values 0 or 1.
        "Sil" is 1 exactly when none of the other classes is present.
    """
    if min_counts is None:
        min_counts = MIN_POSITIVE_ROWS
    labels = list(min_counts)

    # Vectorized: per file, count rows whose score exceeds the threshold.
    positive_rows = pred[labels].gt(score_threshold).groupby(pred["filename"]).sum()
    # Compare each class column against its own minimum (aligned on column name).
    present = positive_rows.ge(pd.Series(min_counts))

    # Silence is defined as the absence of every other class; any "Sil" score
    # column in the input is intentionally not consulted.
    present["Sil"] = ~present.any(axis=1)

    return present.astype(int).reset_index()


def _postprocessed_path(predictions_path):
    """Return the output path: ``<input without .csv>_postprocessed.csv``.

    Only a trailing ".csv" extension is stripped, so a ".csv" appearing
    elsewhere in the path (e.g. in a directory name) is left untouched.
    Names without a ".csv" extension are kept whole and get the suffix appended.
    """
    root, ext = os.path.splitext(predictions_path)
    if ext.lower() != ".csv":
        root = predictions_path
    return root + "_postprocessed.csv"


def main(argv=None):
    """Post-process the predictions CSV named by ``--path``.

    Reads the CSV, aggregates its rows into per-filename 0/1 labels, prints
    the last rows, and writes them to ``_postprocessed_path(path)``.

    Args:
        argv: argument list for argparse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The post-processed DataFrame that was written.
    """
    parser = argparse.ArgumentParser(
        description="Turn row-level class scores into per-file 0/1 labels.")
    parser.add_argument(
        "--path",
        type=str,
        required=True,
        help="Predictions CSV with 'filename', 'Anth', 'Bio' and 'Geo' columns.",
    )
    args = parser.parse_args(argv)

    pred = _aggregate_per_file(pd.read_csv(args.path))
    print(pred.tail())
    pred.to_csv(_postprocessed_path(args.path), index=False)
    return pred


if __name__ == '__main__':
    main()