Spaces:
Starting
Starting
import os | |
from datetime import datetime | |
from pathlib import Path | |
import torch | |
import folium | |
import streamlit as st | |
from loguru import logger | |
from tqdm import tqdm | |
from streamlit_folium import st_folium | |
from transformers import SegformerForSemanticSegmentation | |
from lib.folium import ( | |
get_clean_rendering_container, | |
create_map, | |
process_raster_and_overlays, | |
) | |
import streamlit.components.v1 as components | |
# --- Page configuration ------------------------------------------------------
# NOTE(review): the page icon below looks like a mis-encoded emoji; confirm the
# intended glyph against the original repository before changing it.
st.set_page_config(page_title="GrowSeg Demo", page_icon="π", layout="wide")

# Workaround for the Streamlit/torch error discussed at
# https://discuss.streamlit.io/t/message-error-about-torch/90886/6
torch.classes.__path__ = []


# --- Logging -----------------------------------------------------------------
# Route loguru output through tqdm.write, following the loguru recipe at
# https://loguru.readthedocs.io/en/stable/resources/recipes.html#interoperability-with-tqdm-iterations
def _write_via_tqdm(message):
    """Forward an already formatted loguru message to ``tqdm.write``."""
    tqdm.write(message, end="")


# Drop the previously registered handlers, then install the tqdm-aware sink.
logger.remove()
logger.add(_write_via_tqdm, colorize=True, format="<green>{message}</green>")
@st.cache_resource
def load_model(hf_path='links-ads/gaia-growseg'):
    """Load the GAIA GRowSeg SegFormer model and prepare it for inference.

    The function is wrapped in ``st.cache_resource``: Streamlit re-executes
    the whole script on every user interaction, and without caching the
    weights would be loaded again on each rerun.

    Args:
        hf_path: Model identifier (or path) forwarded to
            ``SegformerForSemanticSegmentation.from_pretrained``.

    Returns:
        The model moved to CUDA when available (CPU otherwise) and switched
        to evaluation mode.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SegformerForSemanticSegmentation.from_pretrained(
        hf_path,
        num_labels=1,
        num_channels=3,
        id2label={1: 'vine'},
        label2id={'vine': 1},
        # Access token is read from the environment; None when the variable is unset.
        token=os.getenv('hf_read_access_token'),
    )
    return model.to(device).eval()


# Load GAIA GRowSeg model (served from the resource cache on script reruns)
model = load_model()
def change_key():
    """Store a fresh, timestamp-based value under ``"key_map"`` in the session state.

    The map widget reads this entry as its component key, so writing a new
    value gives the widget a different key on the next script run.
    """
    fresh_key = str(datetime.now())
    st.session_state["key_map"] = fresh_key
# Create selection menu

# Message shown for the example rasters whose precomputed maps are disabled.
_ONLY_PORTUGAL_MSG = "At this stage, only Portugal is available due to the WebSocket payload limit."

# Directory where uploaded rasters are stored before processing.
_UPLOAD_DIR = "temp"

# Text of the "More info on the model" expander.
_MODEL_INFO = """
Under the hood, this model is a SegFormer-b5, trained on
UAV-acquired vineyard orthoimages and their ground-truth
delineation masks. Paper will be available soon. Stay tuned!
"""

# At most one of these ends up set, depending on the user's choice below.
precomputed_map_path = None
raster_path = None

container_predictions = st.container(border=True)
with container_predictions:
    col1, col2 = st.columns([0.3, 0.7])

    with col1:
        raster_selection = st.selectbox(
            "Select an example or your own raster...",
            options=[
                "Italy",
                "Portugal",
                "Spain",
                "Upload file...",
            ],
            key="raster_selection_block",
            index=None,
            placeholder="Choose an example or upload your own raster",
        )

        if raster_selection in ("Italy", "Spain"):
            # TODO GEOSERVER: the precomputed maps for these examples
            # (data/italy_2022-06-13_cropped.html and
            # data/spain_2022-07-29_cropped.html) are currently disabled.
            st.markdown(_ONLY_PORTUGAL_MSG)
        elif raster_selection == "Portugal":
            precomputed_map_path = "data/portugal_2023-08-01.html"
        elif raster_selection == "Upload file...":
            uploaded_file = st.file_uploader(
                "Upload a raster file",
                type=["tif"],
                key="uploaded_file_block",
            )
            if uploaded_file is not None:
                # Keep only the final path component of the client-supplied name.
                fn = Path(uploaded_file.name).name
                logger.info(f"Received uploaded raster: {fn}")
                # The upload directory may not exist yet (e.g. on a fresh
                # checkout); create it so the write below cannot fail with
                # FileNotFoundError.
                os.makedirs(_UPLOAD_DIR, exist_ok=True)
                raster_path = os.path.join(_UPLOAD_DIR, fn)
                with open(raster_path, "wb") as f:
                    f.write(uploaded_file.getbuffer())

    # Flags consumed by the processing and rendering steps further down.
    is_raster_path_selected = raster_path is not None
    is_precomputed_map_selected = precomputed_map_path is not None

    with col2:
        with st.container():
            # Empty level-6 heading; presumably a vertical spacer so the
            # expander lines up with the select box (TODO confirm).
            st.write("######")
            with st.expander("More info on the model"):
                st.write(_MODEL_INFO)
# Run the model only when the user supplied a raster and no precomputed map
# was chosen.
if is_raster_path_selected and not is_precomputed_map_selected:
    progress_bar = st.progress(0, text="Begin processing...")
    # Process the raster and collect the overlays to draw; the progress bar
    # is handed to the helper and is not cleared afterwards.
    overlays = process_raster_and_overlays(
        raster_path,
        model,
        _progress_bar=progress_bar,
    )

# Placeholder element that will host the map form.
container = st.empty()

# Build the base map, then attach every overlay produced above.
interactive_map = create_map()
if is_raster_path_selected:
    for overlay in overlays:
        overlay.add_to(interactive_map)
# Render the map inside a form: either the precomputed HTML page or the
# interactive folium map, followed by the form's submit button.
with container.form(key="form1"):
    if is_precomputed_map_selected:
        # Load precomputed map. The encoding is given explicitly so reading
        # does not depend on the platform's default locale.
        with open(precomputed_map_path, "r", encoding="utf-8") as f:
            html_content = f.read()
        # The return value is deliberately not bound to `interactive_map`,
        # so that name keeps referring to the folium map created earlier.
        components.html(html_content, height=500)
    else:
        # `overlays` only exists when a raster was selected, hence the
        # short-circuit; the length check avoids an IndexError when
        # processing produced no overlay at all.
        if is_raster_path_selected and len(overlays) > 0:
            # Center map on overlays
            bounds = overlays[0].get_bounds()
            interactive_map.fit_bounds(bounds)
        else:
            # Center map on Europe
            interactive_map.fit_bounds([[35, -10], [60, 40]])

        # Add Layer Control (first remove existing one). Iterate over a
        # copy because entries are deleted during the loop.
        for key, child in list(interactive_map._children.items()):
            if isinstance(child, folium.map.LayerControl):
                del interactive_map._children[key]
        folium.LayerControl().add_to(interactive_map)

        # Folium Map component
        output_map = st_folium(
            interactive_map,
            width=None,
            height=500,
            returned_objects=["all_drawings"],
            # This is a workaround to force the map to recenter: the key is
            # replaced by `change_key` when the button below is pressed.
            key=st.session_state.get("key_map", "key_map"),
        )

    # Recenter map: `change_key` stores a new `key_map` value before the
    # rerun, so the map component above receives a different key.
    submit = st.form_submit_button("Recenter map", on_click=change_key)