# archis99's picture
# Initial Commit
# d587b0b
import sys
import os
from pathlib import Path
# Put the directory two levels above this file on the import path so the
# `backend` package imported below can be found (presumably the repository
# root -- confirm against the project layout).
# The membership check keeps the entry from being inserted again every time
# the script body is re-executed within the same Python process.
project_root = Path(__file__).resolve().parent.parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
import streamlit as st
import pandas as pd
from backend.pipelines.run_inference import predict
from io import BytesIO
# Configure the browser tab (title, icon, wide layout) and render the page header.
st.set_page_config(page_title="Study Status Prediction", page_icon="📊", layout="wide")
st.title("📊 Study Status Prediction")
st.markdown(
    "Upload a CSV file or manually enter study details to predict whether a study is **COMPLETED** or **NOT COMPLETED**."
)

# Tabs for CSV Upload and Manual Entry
tab1, tab2 = st.tabs(["📂 Upload CSV", "✍️ Manual Entry"])
# --- Option 1: CSV Upload ---
# Batch mode: read an uploaded CSV, run it through `predict`, preview the
# resulting labels and offer them as a downloadable CSV.
with tab1:
    uploaded_file = st.file_uploader("Upload a CSV file", type="csv")

    # The uploader yields None until the user has actually selected a file.
    if uploaded_file is not None:
        try:
            df_new = pd.read_csv(uploaded_file)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
            # Report an unreadable upload inline instead of replacing the page
            # with a raw traceback; the rest of the app keeps rendering.
            st.error(f"Could not read the uploaded CSV: {exc}")
        else:
            # NOTE(review): `predict` is assumed to return a mapping that has a
            # "final_predictions" entry with one label per input row -- confirm
            # against backend.pipelines.run_inference.
            preds = predict(df_new)

            # Only keep final predictions
            final_preds = pd.DataFrame({"Final Prediction": preds["final_predictions"]})

            # Display a preview (first rows only)
            st.subheader("🔎 Predictions Preview")
            st.dataframe(final_preds.head())

            # Download button for the full prediction table. Serialising to a
            # string and encoding it avoids the intermediate binary buffer.
            st.download_button(
                label="📥 Download Predictions CSV",
                data=final_preds.to_csv(index=False).encode("utf-8"),
                file_name="predictions.csv",
                mime="text/csv",
            )
# --- Option 2: Manual Entry ---
# Single-record mode: collect one study's fields through form widgets, build a
# one-row DataFrame from them and show the predicted status.
with tab2:
    st.subheader("✍️ Enter Study Details")
    st.markdown("Fill in the fields below to predict the study status.")

    # One input widget per feature; widget order here is the on-page order.
    nct_number = st.text_input("NCT Number", placeholder="e.g., NCT01234567")
    study_title = st.text_area("Study Title", placeholder="e.g., A Study of Drug X in Treating Lung Cancer")
    study_url = st.text_input("Study URL", placeholder="e.g., https://clinicaltrials.gov/ct2/show/NCT01234567")
    acronym = st.text_input("Acronym", placeholder="e.g., LUNG-X")
    brief_summary = st.text_area(
        "Brief Summary",
        placeholder="e.g., This is a phase 3 trial evaluating the effectiveness of Drug X for lung cancer.",
    )
    study_results = st.selectbox("Study Results", ["YES", "NO"])
    conditions = st.text_input("Conditions", placeholder="e.g., Lung Cancer")
    interventions = st.text_input("Interventions", placeholder="e.g., Drug Y")
    primary_outcome = st.text_input("Primary Outcome Measures", placeholder="e.g., Survival rate")
    secondary_outcome = st.text_input("Secondary Outcome Measures", placeholder="e.g., Side effects")
    other_outcome = st.text_input("Other Outcome Measures", placeholder="Optional")
    sponsor = st.text_input("Sponsor", placeholder="e.g., ABC Research")
    collaborators = st.text_input("Collaborators", placeholder="e.g., University of SFX")
    sex = st.selectbox("Sex", ["ALL", "MALE", "FEMALE"])
    age = st.selectbox(
        "Age",
        [
            "ADULT, OLDER_ADULT",
            "ADULT",
            "CHILD, ADULT, OLDER_ADULT",
            "CHILD",
            "CHILD, ADULT",
            "OLDER_ADULT",
        ],
    )
    phases = st.selectbox(
        "Phases",
        [
            "PHASE2",
            "PHASE1",
            "PHASE4",
            "PHASE3",
            "PHASE1|PHASE2",
            "PHASE2|PHASE3",
            "EARLY_PHASE1",
        ],
    )
    enrollment = st.number_input("Enrollment", min_value=0, step=1, placeholder="e.g., 500")
    funder_type = st.selectbox(
        "Funder Type",
        [
            "OTHER",
            "INDUSTRY",
            "NIH",
            "OTHER_GOV",
            "NETWORK",
            "FED",
            "INDIV",
            "UNKNOWN",
            "AMBIG",
        ],
    )
    study_type = st.selectbox("Study Type", ["INTERVENTIONAL", "OBSERVATIONAL"])
    study_design = st.text_area(
        "Study Design",
        placeholder="e.g., Intervention Model: PARALLEL | Masking: SINGLE (INVESTIGATOR)",
    )
    other_ids = st.text_input("Other IDs", placeholder="e.g., ABC-123")
    start_date = st.text_input("Start Date", placeholder="e.g., January 2023")
    primary_completion_date = st.text_input("Primary Completion Date", placeholder="e.g., December 2025")
    completion_date = st.text_input("Completion Date", placeholder="e.g., June 2026")
    first_posted = st.text_input("First Posted", placeholder="e.g., February 2023")
    results_first_posted = st.text_input("Results First Posted", placeholder="e.g., N/A")
    last_update_posted = st.text_input("Last Update Posted", placeholder="e.g., September 2025")
    locations = st.text_area("Locations", placeholder="e.g., New York, USA")
    study_documents = st.text_area("Study Documents", placeholder="e.g., Protocol PDF")

    if st.button("🔮 Predict Status"):
        # Keys become the DataFrame column names handed to `predict`.
        # NOTE(review): presumably these must match the columns the inference
        # pipeline was trained on -- confirm against backend.pipelines.run_inference.
        single_data = {
            "NCT Number": nct_number,
            "Study Title": study_title,
            "Study URL": study_url,
            "Acronym": acronym,
            "Brief Summary": brief_summary,
            "Study Results": study_results,
            "Conditions": conditions,
            "Interventions": interventions,
            "Primary Outcome Measures": primary_outcome,
            "Secondary Outcome Measures": secondary_outcome,
            "Other Outcome Measures": other_outcome,
            "Sponsor": sponsor,
            "Collaborators": collaborators,
            "Sex": sex,
            "Age": age,
            "Phases": phases,
            "Enrollment": enrollment,
            "Funder Type": funder_type,
            "Study Type": study_type,
            "Study Design": study_design,
            "Other IDs": other_ids,
            "Start Date": start_date,
            "Primary Completion Date": primary_completion_date,
            "Completion Date": completion_date,
            "First Posted": first_posted,
            "Results First Posted": results_first_posted,
            "Last Update Posted": last_update_posted,
            "Locations": locations,
            "Study Documents": study_documents,
        }

        # Wrap the dict in a list so it becomes a single-row frame.
        df_single = pd.DataFrame([single_data])
        preds = predict(df_single)

        # Show the predicted label for the only row: a success banner for
        # "COMPLETED", an error banner for anything else.
        final_label = preds["final_predictions"][0]
        if final_label == "COMPLETED":
            st.success(f"✅ Prediction: **{final_label}**", icon="🎉")
        else:
            st.error(f"❌ Prediction: **{final_label}**", icon="⚠️")