# Last commit by tommonopoli (03e7460): "load app & the rest"
import os
from datetime import datetime
from pathlib import Path
import torch
import folium
import streamlit as st
from loguru import logger
from tqdm import tqdm
from streamlit_folium import st_folium
from transformers import SegformerForSemanticSegmentation
from lib.folium import (
get_clean_rendering_container,
create_map,
process_raster_and_overlays,
)
import streamlit.components.v1 as components
# Page configs: browser-tab title/icon and wide layout for the map.
st.set_page_config(page_title="GrowSeg Demo", page_icon="🍇", layout="wide")

# BUGFIX (https://discuss.streamlit.io/t/message-error-about-torch/90886/6):
# per the linked thread, emptying ``torch.classes.__path__`` keeps Streamlit's
# module watcher from erroring when it inspects torch's custom-class namespace.
torch.classes.__path__ = []

# Interoperability with tqdm: route loguru output through ``tqdm.write`` so log
# lines do not corrupt active progress bars
# (https://loguru.readthedocs.io/en/stable/resources/recipes.html#interoperability-with-tqdm-iterations)
logger.remove()
# ``end=""`` because loguru messages already carry their own trailing newline.
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, format="<green>{message}</green>")
@st.cache_resource
def load_model(hf_path='links-ads/gaia-growseg'):
    """Load the GAIA GrowSeg SegFormer segmentation model.

    Wrapped in ``st.cache_resource`` so script reruns reuse the same model
    object instead of reloading the weights each time.

    Args:
        hf_path: Hugging Face Hub repo id (or local path) handed to
            ``SegformerForSemanticSegmentation.from_pretrained``.

    Returns:
        The model in eval mode, on CUDA when available, otherwise on CPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Logged only after `device` exists; the previous commented-out version of
    # this line referenced `device` before it was assigned.
    logger.info(f'Loading GAIA GrowSeg on {device}...')
    model = SegformerForSemanticSegmentation.from_pretrained(
        hf_path,
        # One output channel ("vine" mask) over 3-channel (RGB) input.
        num_labels=1,
        num_channels=3,
        # NOTE(review): with num_labels=1 the only output index is 0, so
        # {1: 'vine'} looks off-by-one. Kept unchanged in case downstream code
        # reads config.id2label; confirm before changing.
        id2label={1: 'vine'},
        label2id={'vine': 1},
        # None when the env var is unset (anonymous Hub access).
        token=os.getenv('hf_read_access_token'),
    )
    return model.to(device).eval()


# Load GAIA GrowSeg model (cached across reruns by st.cache_resource)
model = load_model()
def change_key():
    """Store a fresh timestamp string under ``st.session_state["key_map"]``.

    The folium map widget reads its ``key`` from this session entry, so a new
    value makes Streamlit mount it as a new widget (the recenter workaround).
    """
    fresh_key = str(datetime.now())
    st.session_state["key_map"] = fresh_key
# Create selection menu
container_predictions = st.container(border=True)
with container_predictions:
    col1, col2 = st.columns([0.3, 0.7])
    with col1:
        precomputed_map_path = None
        raster_path = None
        # Overlays computed for an uploaded raster. Initialised here so the
        # map code below never depends on a conditionally defined name.
        overlays = []
        raster_selection = st.selectbox(
            "Select an example or your own raster...",
            options=[
                "Italy",
                "Portugal",
                "Spain",
                "Upload file...",
            ],
            key="raster_selection_block",
            index=None,
            placeholder="Choose an example or upload your own raster",
        )
        if raster_selection == "Italy":
            st.markdown("At this stage, only Portugal is available due to the WebSocket payload limit.")
            # TODO GEOSERVER: serve "data/italy_2022-06-13_cropped.html"
        elif raster_selection == "Portugal":
            precomputed_map_path = "data/portugal_2023-08-01.html"
        elif raster_selection == "Spain":
            st.markdown("At this stage, only Portugal is available due to the WebSocket payload limit.")
            # TODO GEOSERVER: serve "data/spain_2022-07-29_cropped.html"
        elif raster_selection == "Upload file...":
            uploaded_file = st.file_uploader(
                "Upload a raster file",
                type=["tif"],
                key="uploaded_file_block",
            )
            if uploaded_file is not None:
                # Keep only the base name so a crafted upload name cannot
                # point outside the temp directory.
                fn = Path(uploaded_file.name).name
                logger.info(f"Received uploaded raster: {fn}")
                # The temp directory may not exist on a fresh deployment.
                os.makedirs("temp", exist_ok=True)
                raster_path = os.path.join("temp", fn)
                with open(raster_path, "wb") as f:
                    f.write(uploaded_file.getbuffer())
        is_raster_path_selected = raster_path is not None
        is_precomputed_map_selected = precomputed_map_path is not None
    with col2:
        with st.container():
            st.write("######")
            with st.expander("More info on the model"):
                st.write("""
Under the hood, this model is a SegFormer-b5, trained on
UAV-acquired vineyard orthoimages and their ground-truth
delineation masks. Paper will be available soon. Stay tuned!
""")
    if not is_precomputed_map_selected and is_raster_path_selected:
        progress_bar = st.progress(0, text="Begin processing...")
        # Process raster and get overlays
        overlays = process_raster_and_overlays(raster_path, model, _progress_bar=progress_bar)

# Placeholder that hosts the map form
container = st.empty()
# draw map
interactive_map = create_map()
# Add overlays to map (empty unless an uploaded raster was processed)
for overlay in overlays:
    overlay.add_to(interactive_map)
with container.form(key="form1"):
    if is_precomputed_map_selected:
        # Embed the pre-rendered folium HTML. Read it explicitly as UTF-8
        # instead of relying on the platform's locale default.
        with open(precomputed_map_path, "r", encoding="utf-8") as f:
            html_content = f.read()
        components.html(html_content, height=500)
    else:
        if overlays:
            # Center map on overlays
            interactive_map.fit_bounds(overlays[0].get_bounds())
        else:
            # Center map on Europe (also the fallback if processing produced
            # no overlays, which previously raised IndexError)
            interactive_map.fit_bounds([[35, -10], [60, 40]])
        # Add Layer Control (first remove existing one); iterate over a copy
        # because entries are deleted from the dict during the loop.
        for key, child in list(interactive_map._children.items()):
            if isinstance(child, folium.map.LayerControl):
                del interactive_map._children[key]
        folium.LayerControl().add_to(interactive_map)
        # Folium Map component
        output_map = st_folium(
            interactive_map,
            width=None,
            height=500,
            returned_objects=["all_drawings"],
            key=st.session_state.get("key_map", "key_map"),  # This is a workaround to force the map to recenter
        )
    # Recenter map: change_key stores a new widget key before the rerun, so
    # st_folium remounts and re-applies fit_bounds.
    submit = st.form_submit_button("Recenter map", on_click=change_key)