ysdaml4 / app.py
ssbars's picture
v2
12faaae
import streamlit as st
# Set up the Streamlit page - this must be the first st command, which is
# why the remaining imports appear below this call rather than above it.
st.set_page_config(
    page_title="Paper Classification Service",
    page_icon="📚",  # books emoji (was mis-encoded text)
    layout="wide",
)
import PyPDF2
import io
from model import PaperClassifier
# Build the classifier for a given model type; wrapped in st.cache_resource
# so the constructed object is reused instead of being rebuilt.
@st.cache_resource
def load_classifier(model_type):
    """Return a ``PaperClassifier`` constructed for ``model_type``.

    Args:
        model_type: Model identifier forwarded unchanged to
            ``PaperClassifier``.

    Returns:
        The ``PaperClassifier`` instance for the requested model.
    """
    classifier_instance = PaperClassifier(model_type)
    return classifier_instance
# Cache the PDF text extraction
@st.cache_data
def extract_pdf_text(pdf_bytes):
    """Extract text from a PDF and heuristically split it into title/abstract.

    The text of every page is concatenated; the first non-blank line is
    treated as the title and everything after it as the abstract.

    Args:
        pdf_bytes: Raw bytes of the PDF document.

    Returns:
        A ``(title, abstract)`` tuple of whitespace-stripped strings. Either
        element is ``""`` when no corresponding text could be extracted.
    """
    pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))

    # Join once instead of repeated "+=", and tolerate pages for which the
    # extractor yields nothing (e.g. image-only pages).
    text = "\n".join((page.extract_text() or "") for page in pdf_reader.pages)

    # Strip before splitting so leading blank lines do not produce an empty
    # title when the document actually contains text.
    title, _, abstract = text.strip().partition("\n")
    return title.strip(), abstract.strip()
# Model identifiers offered to the user (keys of the classifier's registry)
available_models = list(PaperClassifier.AVAILABLE_MODELS)

# Sidebar: model picker
st.sidebar.title("Model Settings")
selected_model = st.sidebar.selectbox(
    label="Select Model",
    options=available_models,
    index=0,
    help="Choose the model to use for classification",
)

# Sidebar: name and description of the chosen model
model_info = PaperClassifier.AVAILABLE_MODELS[selected_model]
st.sidebar.markdown(f"""
### Selected Model
**Name:** {model_info['name']}
**Description:** {model_info['description']}
""")

# Classifier used by both input options below
classifier = load_classifier(selected_model)
# Page title and a short description of the two supported input methods
st.title("📚 Academic Paper Classification")
st.markdown("""
This service helps you classify academic papers into different categories.
You can either:
- Enter the paper's title and abstract separately
- Upload a PDF file
""")
def _render_classification(result):
    """Render a classification result in the current Streamlit container.

    Args:
        result: Mapping returned by ``classifier.classify_paper``. The keys
            read here are ``'input_type'``, ``'model_used'`` and
            ``'top_categories'`` (an iterable of mappings with
            ``'category'``, ``'arxiv_category'`` and ``'probability'``).
    """
    st.success("Classification Complete!")
    st.write(f"**Input Type:** {result['input_type'].replace('_', ' ').title()}")
    st.write(f"**Model Used:** {result['model_used']}")

    # Show top categories
    st.subheader("Top Categories (95% Confidence)")
    total_prob = 0.0
    for cat_info in result['top_categories']:
        # Coerce to a plain float and clamp for the bar only: st.progress
        # rejects floats outside [0.0, 1.0]. The label keeps the raw value.
        prob = float(cat_info['probability'])
        total_prob += prob
        st.progress(
            min(max(prob, 0.0), 1.0),
            text=f"{cat_info['category']} ({cat_info['arxiv_category']}): {prob:.1%}",
        )
    st.info(f"Total probability of shown categories: {total_prob:.1%}")


# Create two columns for input methods
col1, col2 = st.columns(2)

with col1:
    st.subheader("Option 1: Manual Input")

    # Title input
    title_input = st.text_input(
        "Paper Title:",
        placeholder="Enter the paper title...",
    )

    # Abstract input
    abstract_input = st.text_area(
        "Paper Abstract (optional):",
        height=200,
        placeholder="Enter the paper abstract (optional)...",
    )

    if st.button("Classify Paper"):
        if title_input.strip():
            with st.spinner("Classifying..."):
                result = classifier.classify_paper(
                    title=title_input,
                    # A blank abstract is passed as None (title-only input)
                    abstract=abstract_input if abstract_input.strip() else None,
                )
                _render_classification(result)
        else:
            st.warning("Please enter at least the paper title.")

with col2:
    st.subheader("Option 2: PDF Upload")
    uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")

    # The button is only offered once a file has been uploaded
    if uploaded_file is not None and st.button("Classify PDF"):
        # UI boundary: any failure while parsing or classifying the PDF is
        # reported to the user instead of crashing the page.
        try:
            with st.spinner("Processing PDF..."):
                # getvalue() returns the whole upload regardless of the
                # stream's current read position, unlike read().
                title, abstract = extract_pdf_text(uploaded_file.getvalue())

                if not title:
                    # Report inline rather than calling st.stop(), which
                    # would also prevent everything after this block (the
                    # sidebar model notes and the footer) from rendering.
                    st.error("Could not extract title from PDF.")
                else:
                    # Show extracted text
                    with st.expander("Show extracted text"):
                        st.write("**Extracted Title:**")
                        st.write(title)
                        if abstract:
                            st.write("**Extracted Abstract:**")
                            st.write(abstract)

                    # Classify the paper
                    result = classifier.classify_paper(
                        title=title,
                        abstract=abstract if abstract else None,
                    )
                    _render_classification(result)
        except Exception as e:
            st.error(f"Error processing PDF: {e}")
# Static notes about the models, shown in the sidebar below the picker
st.sidebar.markdown("---")
st.sidebar.title("Available Models")
st.sidebar.markdown("""
- **DistilBERT**: Fast and lightweight
- **DeBERTa v3**: Advanced performance
- **T5**: Versatile text-to-text
- **RoBERTa**: Strong performance
- **SciBERT**: Specialized for science
""")

# Page footer (heart emoji restored from mis-encoded text)
st.markdown("---")
st.markdown("Made with ❤️ using Streamlit and Transformers")