"""Post-process per-window class scores into per-file 0/1 labels.

Reads a predictions CSV with a ``filename`` column and one score column per
class, aggregates the rows of each file, and writes
``<input>_postprocessed.csv`` next to the input.
"""
import os
import sys
import argparse
from glob import glob

import numpy as np
import pandas as pd

# Per-class decision rule: label -> (score_threshold, min_count).
# A file gets the label when at least ``min_count`` of its rows have a score
# strictly greater than ``score_threshold`` for that label.
EVENT_RULES = {
    "Anth": (0.5, 5),
    "Bio": (0.5, 2),
    "Geo": (0.5, 3),
}

# Name of the derived "nothing detected" column.
SILENCE_LABEL = "Sil"


def postprocess(pred, rules=None):
    """Aggregate row-level scores into one 0/1 label row per file.

    Args:
        pred: DataFrame with a ``filename`` column and one numeric score
            column for every label in ``rules``. Other columns are ignored.
        rules: mapping label -> (score_threshold, min_count). Defaults to
            ``EVENT_RULES``. Output columns follow the mapping's order.

    Returns:
        DataFrame with columns ``filename``, one int (0/1) column per rule
        label, then ``Sil``. It has one row per distinct filename, sorted by
        filename.

    Raises:
        KeyError: if ``filename`` or a rule's score column is missing.
    """
    if rules is None:
        rules = EVENT_RULES

    flags = {}
    for label, (score_threshold, min_count) in rules.items():
        # Count this file's rows that exceed the threshold, then apply the rule.
        hits = (pred[label] > score_threshold).groupby(pred["filename"]).sum()
        flags[label] = hits >= min_count

    out = pd.DataFrame(flags)
    out.index.name = "filename"
    # Silence is not thresholded from a model score: a file is silent exactly
    # when none of the event labels fired. Operating on a bool frame keeps
    # ``~`` a logical NOT (``~`` on a Python bool would yield -1/-2).
    out[SILENCE_LABEL] = ~out[list(rules)].any(axis=1)
    return out.astype(int).reset_index()


def postprocessed_path(predictions_path):
    """Return the output path for ``predictions_path``.

    Only a trailing ``.csv`` is stripped, so directories whose names contain
    ``.csv`` are left intact.
    """
    suffix = ".csv"
    root = (
        predictions_path[: -len(suffix)]
        if predictions_path.endswith(suffix)
        else predictions_path
    )
    return root + "_postprocessed.csv"


def main(argv=None):
    """CLI entry point: ``--path`` names the predictions CSV to process."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--path", type=str, required=True, help="predictions CSV to post-process"
    )
    args = parser.parse_args(argv)

    result = postprocess(pd.read_csv(args.path))
    print(result.tail())
    result.to_csv(postprocessed_path(args.path), index=False)


if __name__ == "__main__":
    main()