till-onethousand commited on
Commit
9506f43
·
1 Parent(s): 4b848f2

base model

Browse files
Files changed (3) hide show
  1. app.py +19 -13
  2. config.py +7 -0
  3. utils.py +15 -0
app.py CHANGED
@@ -1,20 +1,26 @@
1
  import streamlit as st
2
- from transformers import pipeline
3
  from PIL import Image
 
 
 
4
 
5
- pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
6
 
7
- st.title("Hot Dog? Or Not?")
 
 
 
 
 
8
 
9
- file_name = st.file_uploader("Upload a hot dog candidate image")
10
 
11
- if file_name is not None:
12
- col1, col2 = st.columns(2)
 
 
13
 
14
- image = Image.open(file_name)
15
- col1.image(image, use_column_width=True)
16
- predictions = pipeline(image)
17
-
18
- col2.header("Probabilities")
19
- for p in predictions:
20
- col2.subheader(f"{ p['label'] }: { round(p['score'] * 100, 1)}%")
 
1
  import streamlit as st
 
2
  from PIL import Image
3
+ from transformers import ViTForImageClassification
4
+ from config import UNTRAINED, labels
5
+ from utils import predict
6
 
 
7
 
8
+ model_untrained = ViTForImageClassification.from_pretrained(
9
+ UNTRAINED,
10
+ num_labels=len(labels),
11
+ id2label={str(i): c for i, c in enumerate(labels)},
12
+ label2id={c: str(i) for i, c in enumerate(labels)},
13
+ )
14
 
15
+ st.title("Detect Hurricane Damage")
16
 
17
+ col1, col2 = st.columns(2)
18
+ with col1:
19
+ st.markdown("## Pre-Trained Model")
20
+ file_name = st.file_uploader("Upload a satellite image")
21
 
22
+ if file_name is not None:
23
+ image = Image.open(file_name)
24
+ col1.image(image, use_container_width=True)
25
+ label = predict
26
+ st.write(f"Predicted label: {label}")
 
 
config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import datasets import load_dataset
2
+
3
+ dataset_name = "jonathan-roberts1/Satellite-Images-of-Hurricane-Damage"
4
+ ds = load_dataset(dataset_name)
5
+ labels = ds['train'].features['label'].names
6
+
7
+ UNTRAINED = 'google/vit-base-patch16-224-in21k'
utils.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import ViTFeatureExtractor
3
+ from config import UNTRAINED
4
+
5
+ feature_extractor = ViTFeatureExtractor.from_pretrained(UNTRAINED)
6
+
7
+ def predict(model, image):
8
+ inputs = feature_extractor(image, return_tensors="pt")
9
+
10
+ with torch.no_grad():
11
+ logits = model(**inputs).logits
12
+
13
+ # model predicts one of the 1000 ImageNet classes
14
+ predicted_label = logits.argmax(-1).item()
15
+ return model.config.id2label[str(predicted_label)]