In [1]:
import pandas as pd
import geopandas as gpd
import rioxarray as rxr
import xarray as xr
import numpy as np
import os
import torch
from transformers import SegformerForSemanticSegmentation
from lib.utils import compute_mask, compute_vndvi, compute_vdi

In [2]:
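# Optional preprocessing (currently disabled): uncomment this cell to crop a raw raster
# to its companion GeoJSON geometry and save the result as "<name>_cropped.tif".
# # Read raster data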
# raster_path = "data/spain_2022-07-29.tif"
# raster = rxr.open_rasterio(raster_path)

# # Crop raster with GeoJSON geometry, if available
# geom_path = raster_path.replace(".tif", ".geojson")
# if os.path.exists(geom_path):
#     geom = gpd.read_file(geom_path)
#     raster = raster.rio.clip(geom.geometry)
#     raster.rio.to_raster(raster_path.replace(".tif", "_cropped.tif"))

In [3]:
def load_model(hf_path='links-ads/gaia-growseg'):
    """Load the GAIA GRowSeg SegFormer checkpoint, ready for inference.

    Args:
        hf_path: Model identifier handed to ``from_pretrained``.

    Returns:
        The ``SegformerForSemanticSegmentation`` instance, moved to CUDA when
        available (CPU otherwise) and switched to eval mode.
    """
    # Access token is read from the environment; os.getenv yields None when unset.
    hub_token = os.getenv('hf_read_access_token')

    segformer = SegformerForSemanticSegmentation.from_pretrained(
        hf_path,
        num_labels=1,
        num_channels=3,
        id2label={1: 'vine'},
        label2id={'vine': 1},
        token=hub_token,
    )

    if torch.cuda.is_available():
        target_device = torch.device('cuda')
    else:
        target_device = torch.device('cpu')

    segformer = segformer.to(target_device)
    return segformer.eval()

# Load GAIA GRowSeg model once with the default checkpoint; the resulting
# module-level `model` is what the mask computation further down passes to compute_mask.
model = load_model()

In [10]:
# Per-raster processing presets, keyed by raster path.
# Switch rasters by changing `raster_path` below instead of (un)commenting blocks.
RASTER_PRESETS = {
    "data/italy_2022-06-13_cropped.tif": {
        "patch_size": 512,
        "stride": 256,
        "scaling_factor": 1.0,
        "dilate_rows": False,
        "window_size": 360,
    },
    "data/spain_2022-07-29_cropped.tif": {
        "patch_size": 512,
        "stride": 256,
        "scaling_factor": 1.0,
        "dilate_rows": False,
        "window_size": 400,
    },
    "data/portugal_2023-08-01.tif": {
        "patch_size": 512,
        "stride": 256,
        "scaling_factor": 1.25,
        "dilate_rows": False,
        "window_size": 80,
    },
}

raster_path = "data/italy_2022-06-13_cropped.tif"
preset = RASTER_PRESETS[raster_path]
patch_size = preset["patch_size"]
stride = preset["stride"]
scaling_factor = preset["scaling_factor"]
dilate_rows = preset["dilate_rows"]
window_size = preset["window_size"]
granularity = window_size // 8


def georeference_like(array, reference):
    """Wrap a (band, y, x) array as a DataArray on the same grid as `reference`.

    Args:
        array: numpy array laid out as (band, y, x).
        reference: rioxarray-enabled DataArray whose x/y coordinates, CRS and
            affine transform are copied onto the result.

    Returns:
        xarray.DataArray with dims ('band', 'y', 'x'), the reference CRS and transform.
    """
    georef = xr.DataArray(
        array,
        dims=('band', 'y', 'x'),
        coords={
            'x': reference.x,
            'y': reference.y,
            # Number bands from the array itself (1-based) rather than reusing
            # reference.band: the derived products have 4 bands, so borrowing the band
            # coordinate of a 3-band orthoimage would raise a size-conflict error.
            'band': np.arange(1, array.shape[0] + 1),
        },
    )
    georef.rio.write_crs(reference.rio.crs, inplace=True)  # Copy CRS
    georef.rio.write_transform(reference.rio.transform(), inplace=True)  # Copy affine transform
    return georef


raster = rxr.open_rasterio(raster_path)
raster_array = raster.to_numpy()  # converted once, shared by all three computations below

# Compute mask (skipped when a cached mask file already exists)
mask_path = raster_path.replace(".tif", "_mask.tif")
if not os.path.exists(mask_path):
    mask = compute_mask(
        raster_array,
        model,
        patch_size=patch_size,
        stride=stride,
        scaling_factor=scaling_factor,
        rotate=False,
        batch_size=16,
    )   # mask is a HxW uint8 array with 0=background, 255=vine, 1=nodata

    # Convert mask from grayscale to RGBA: vine pixels in the red channel,
    # nodata pixels (value 1) made fully transparent through the alpha channel
    alpha = ((mask != 1) * 255).astype(np.uint8)
    empty_channel = np.zeros_like(mask)
    mask_colored = np.stack([mask, empty_channel, empty_channel, alpha], axis=0)  # 4xHxW uint8

    # Georef mask like raster and save it
    mask_raster = georeference_like(mask_colored, raster)
    mask_raster.rio.to_raster(mask_path, compress='lzw')
else:
    # Band 1 is the red channel, i.e. the original grayscale mask values
    mask = rxr.open_rasterio(mask_path).sel(band=1).squeeze().to_numpy()

# Compute vNDVI (recomputed when either of the two output files is missing)
vndvi_rows_path = raster_path.replace(".tif", "_vndvi_rows.tif")
vndvi_interrows_path = raster_path.replace(".tif", "_vndvi_interrows.tif")
if not os.path.exists(vndvi_rows_path) or not os.path.exists(vndvi_interrows_path):
    vndvi_rows, vndvi_interrows = compute_vndvi(
        raster_array,
        mask,
        dilate_rows=dilate_rows,
        window_size=window_size,
        granularity=granularity,
    )    # vNDVI is RGBA

    # Georef vNDVI like raster (channel-last -> band-first)
    vndvi_rows_raster = georeference_like(vndvi_rows.transpose(2, 0, 1), raster)
    vndvi_interrows_raster = georeference_like(vndvi_interrows.transpose(2, 0, 1), raster)

    # Save vNDVI
    vndvi_rows_raster.rio.to_raster(vndvi_rows_path, compress='lzw')
    vndvi_interrows_raster.rio.to_raster(vndvi_interrows_path, compress='lzw')

# Compute VDI
vdi_path = raster_path.replace(".tif", "_vdi.tif")
if not os.path.exists(vdi_path):
    vdi = compute_vdi(
        raster_array,
        mask,
        window_size=window_size,
        granularity=granularity,
    )    # VDI is RGBA

    # Georef VDI like raster (channel-last -> band-first)
    vdi_raster = georeference_like(vdi.transpose(2, 0, 1), raster)

    # Save results
    vdi_raster.rio.to_raster(vdi_path, compress='lzw')

[32m2025-03-20 12:39:09.921[0m | [1mINFO    [0m | [36mlib.utils[0m:[36msliding_window_avg_pooling[0m:[36m308[0m - [1mExtracting patches idx...[0m
100%|█████████████████████████████████████████████| 67848/67848 [00:03<00:00, 20745.29it/s]
[32m2025-03-20 12:39:14.795[0m | [1mINFO    [0m | [36mlib.utils[0m:[36msliding_window_avg_pooling[0m:[36m308[0m - [1mExtracting patches idx...[0m
100%|█████████████████████████████████████████████| 67848/67848 [00:03<00:00, 19329.36it/s]
[32m2025-03-20 12:39:56.011[0m | [1mINFO    [0m | [36mlib.utils[0m:[36msliding_window_avg_pooling[0m:[36m308[0m - [1mExtracting patches idx...[0m
100%|██████████████████████████████████████████████| 64758/64758 [00:20<00:00, 3203.45it/s]


In [11]:
import folium
from loguru import logger

def create_map(location=(41.9099533, 12.3711879), zoom_start=5, crs=3857, max_zoom=23):
    """Create a folium map with OpenStreetMap tiles and optional Esri.WorldImagery basemap.

    Args:
        location: Initial map centre passed to ``folium.Map``.
        zoom_start: Initial zoom level.
        crs: Map CRS, given as an EPSG integer (``3857``) or a string such as
            ``"EPSG3857"`` / ``"EPSG:3857"`` (case-insensitive). Only EPSG:3857
            is accepted.
        max_zoom: Maximum zoom level of the map.

    Returns:
        folium.Map with OpenStreetMap as the default tiles and an extra,
        initially hidden Esri.WorldImagery tile layer.

    Raises:
        AssertionError: If ``crs`` does not resolve to EPSG:3857.
    """
    # Normalise to the colon-less spelling handed to folium ("EPSG3857"), so that
    # the common "EPSG:3857" form is accepted instead of tripping the check below.
    if isinstance(crs, int):
        crs = f"EPSG{crs}"
    elif isinstance(crs, str):
        crs = crs.upper().replace(":", "")
    assert crs in ["EPSG3857"], f"Only EPSG:3857 supported for now. Got {crs}."

    m = folium.Map(
        location=location,
        zoom_start=zoom_start,
        crs=crs,
        max_zoom=max_zoom,
        tiles="OpenStreetMap",  # Esri.WorldImagery
        attributionControl=False,
        prefer_canvas=True,
    )

    # Add Esri.WorldImagery as optional basemap (radio button)
    folium.TileLayer(
        tiles="Esri.WorldImagery",
        show=False,
        overlay=False,
        control=True,
    ).add_to(m)

    return m

def create_image_overlay(raster_path_or_array, name="Raster", opacity=1.0, to_crs=4326, show=True):
    """Create a folium image overlay from a raster file path or an xarray.DataArray.

    Args:
        raster_path_or_array: Path to a raster file (``str`` or ``os.PathLike``),
            or an already opened rioxarray-enabled ``xarray.DataArray`` with
            ``band``, ``y`` and ``x`` dimensions.
        name: Name given to the overlay layer.
        opacity: Opacity forwarded to the folium overlay.
        to_crs: EPSG code the raster is reprojected to when its own CRS differs.
        show: Forwarded to the folium overlay's ``show`` option.

    Returns:
        folium.raster_layers.ImageOverlay whose bounds are the raster bounds
        (after reprojection, if any).
    """
    # Accept pathlib.Path and other os.PathLike objects as well as plain strings;
    # previously only `str` was recognised as a path.
    if isinstance(raster_path_or_array, (str, os.PathLike)):
        # Open the raster and its metadata
        logger.info(f"Opening raster: {raster_path_or_array!r}...")
        layer = rxr.open_rasterio(raster_path_or_array)
    else:
        layer = raster_path_or_array

    # Fall back to 0 when the raster declares no nodata value (or declares 0)
    nodata = layer.rio.nodata or 0
    if layer.rio.crs.to_epsg() != to_crs:
        logger.info(f"Reprojecting raster to EPSG:{to_crs} with NODATA value {nodata}...")
        layer = layer.rio.reproject(to_crs, nodata=nodata) # nodata default: 255

    # Channel-last layout (y, x, band) for the image handed to folium
    layer = layer.transpose("y", "x", "band")
    left, bottom, right, top = layer.rio.bounds()   # (left, bottom, right, top)

    # Create a folium image overlay
    logger.info(f"Creating overlay: {name!r}...")
    overlay = folium.raster_layers.ImageOverlay(
        image=layer.to_numpy(),
        name=name,
        bounds=[[bottom, left], [top, right]],    # format for folium: ((bottom,left),(top,right))
        opacity=opacity,
        interactive=True,
        cross_origin=False,
        zindex=1,
        show=show,
    )

    return overlay

# Define paths
# NOTE(review): this raster differs from the one processed in the previous cell
# ("data/italy_2022-06-13_cropped.tif"). The derived *_mask / *_vndvi_* / *_vdi files
# for this raster must already exist on disk or the loads below fail - confirm intended.
raster_path = "data/portugal_2023-08-01.tif"
mask_path = raster_path.replace('.tif', '_mask.tif')
vndvi_rows_path = raster_path.replace('.tif', '_vndvi_rows.tif')
vndvi_interrows_path = raster_path.replace('.tif', '_vndvi_interrows.tif')
vdi_path = raster_path.replace('.tif', '_vdi.tif')

# Load rasters (orthoimage first, then every derived product)
raster, mask_raster, vndvi_rows_raster, vndvi_interrows_raster, vdi_raster = (
    rxr.open_rasterio(path)
    for path in (raster_path, mask_path, vndvi_rows_path, vndvi_interrows_path, vdi_path)
)

# Reproject all rasters to EPSG:4326.
# The decision is taken from the orthoimage CRS alone - assumes the derived
# products share that CRS (TODO confirm for files produced elsewhere).
if raster.rio.crs.to_epsg() != 4326:
    logger.info("Reprojecting rasters to EPSG:4326 with NODATA value 0...")
    raster, mask_raster, vndvi_rows_raster, vndvi_interrows_raster, vdi_raster = (
        layer.rio.reproject("EPSG:4326", nodata=0)    # nodata default: 255
        for layer in (raster, mask_raster, vndvi_rows_raster, vndvi_interrows_raster, vdi_raster)
    )


def _logged_overlay(label, layer, name, show):
    """Log which overlay is being built, then build it at full opacity."""
    logger.info(f'Creating {label} overlay...')
    return create_image_overlay(layer, name=name, opacity=1.0, show=show)


# Create overlays (only the orthoimage is visible by default)
raster_overlay = _logged_overlay("RGB raster", raster, "Orthoimage", show=True)
mask_overlay = _logged_overlay("mask", mask_raster, "Mask", show=False)
vndvi_rows_overlay = _logged_overlay("vNDVI rows", vndvi_rows_raster, "vNDVI Rows", show=False)
vndvi_interrows_overlay = _logged_overlay("vNDVI interrows", vndvi_interrows_raster, "vNDVI Interrows", show=False)
vdi_overlay = _logged_overlay("VDI", vdi_raster, "VDI", show=False)

[32m2025-03-20 12:40:30.816[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m76[0m - [1mReprojecting rasters to EPSG:4326 with NODATA value 0...[0m
[32m2025-03-20 12:40:52.371[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m84[0m - [1mCreating RGB raster overlay...[0m
[32m2025-03-20 12:40:52.373[0m | [1mINFO    [0m | [36m__main__[0m:[36mcreate_image_overlay[0m:[36m46[0m - [1mCreating overlay: 'Orthoimage'...[0m
[32m2025-03-20 12:40:58.801[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m86[0m - [1mCreating mask overlay...[0m
[32m2025-03-20 12:40:58.806[0m | [1mINFO    [0m | [36m__main__[0m:[36mcreate_image_overlay[0m:[36m46[0m - [1mCreating overlay: 'Mask'...[0m
[32m2025-03-20 12:41:05.006[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1mCreating vNDVI rows overlay...[0m
[32m2025-03-20 12:41:05.008[0m | [1mINFO    [0m | [36m__main__[0m:[36mcreate_image_overlay

In [12]:
# Build the base map, then attach every overlay in the order they should be listed
m = create_map()
for overlay in (
    raster_overlay,
    mask_overlay,
    vndvi_rows_overlay,
    vndvi_interrows_overlay,
    vdi_overlay,
):
    overlay.add_to(m)

# Add layer control
layer_control = folium.LayerControl()
layer_control.add_to(m)

# Fit map to the orthoimage bounds
orthoimage_bounds = raster_overlay.get_bounds()
m.fit_bounds(orthoimage_bounds)

# Save map next to the source raster, swapping the .tif extension for .html
map_path = raster_path.replace('.tif', '.html')
m.save(map_path)