|
import streamlit as st
from PIL import Image
from PIL import UnidentifiedImageError
from transformers import ViTForImageClassification

from config import UNTRAINED, labels, TRAINED
from utils import predict
|
|
|
|
|
@st.cache_resource
def _load_model(checkpoint):
    """Load a ViT classifier configured with this app's label set.

    Streamlit re-runs the whole script on every user interaction, so without
    caching both checkpoints would be reloaded from disk on every rerun.
    ``st.cache_resource`` keeps one instance per ``checkpoint`` for the life
    of the server process, shared across sessions.

    Args:
        checkpoint: Value passed to ``from_pretrained`` (a hub id or local
            path, presumably defined in ``config`` -- confirm there).

    Returns:
        A ``ViTForImageClassification`` whose head has ``len(labels)``
        outputs and whose id/label mappings follow the order of ``labels``.
    """
    # NOTE(review): the cached model is shared by all sessions; this assumes
    # `predict` does not mutate it -- confirm in utils.predict.
    return ViTForImageClassification.from_pretrained(
        checkpoint,
        num_labels=len(labels),
        id2label={str(i): c for i, c in enumerate(labels)},
        label2id={c: str(i) for i, c in enumerate(labels)},
    )


# Baseline checkpoint vs. the fine-tuned one, compared side by side below.
model_untrained = _load_model(UNTRAINED)
model_trained = _load_model(TRAINED)
|
|
|
st.title("Detect Hurricane Damage")

# Returns an UploadedFile (file-like), or None until the user uploads one.
file_name = st.file_uploader("Upload a satellite image")

# Stays None unless the upload could be opened as an image; both prediction
# columns key off this, so an unreadable file never reaches `predict`.
image = None
if file_name is not None:
    try:
        image = Image.open(file_name)
    except UnidentifiedImageError:
        # The uploader accepts any file type; report instead of crashing.
        st.error("Could not read the uploaded file as an image.")
    else:
        st.image(image, use_container_width=True)

col1, col2 = st.columns(2)

# Show both models' predictions side by side for the same image.
for column, heading, model in (
    (col1, "## Pre-Trained Model", model_untrained),
    (col2, "## Fine-Tuned Model", model_trained),
):
    with column:
        st.markdown(heading)
        if image is not None:
            label = predict(model, image)
            st.write(f"Predicted label: {label}")
|
|