# Spaces: Sleeping / Sleeping  (hosting-page status banner captured with the source; not program text)
# app.py
import base64
import io
import json
import os
import tempfile
from datetime import datetime

import geopandas as gpd
import matplotlib
import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotnine as p9
import rasterio
from ipyleaflet import (
    Map,
    TileLayer,
    basemaps,
    ColorMap,
    RasterLayer,
    LegendControl,
    GeoJSON,
    MarkerCluster,
    Marker,
    DivIcon,
    Polygon,
)
from matplotlib.colors import BoundaryNorm
from rasterio.plot import show
# NOTE: this import rebinds `Polygon`, shadowing the ipyleaflet class imported above.
from shapely import LineString, Polygon
from shiny import App, ui, reactive, render
from shinywidgets import output_widget, register_widget

from helpers.fetch_data import fetch_data
from helpers.reduce_precision import reduce_coordinate_precision
from helpers.residuals import get_residual_plot
# ------------------------------
# 1. Data & Config
# ------------------------------
# Labels for the time periods, one per band of the GeoTIFF (band 1 -> first label).
time_periods = [
    "1990–1992", "1993–1995", "1996–1998", "1999–2001", "2002–2004",
    "2005–2007", "2008–2010", "2011–2013", "2014–2016", "2017–2019",
]

# Multi-band GeoTIFF. The dataset is intentionally left open: later module-level
# code and the server read individual bands from it.
wealth_stack = rasterio.open("wealth_map.tif")

# Country outlines. GeoJSON is UTF-8 by specification, so do not rely on the
# platform default encoding.
with open('data/no_somaliland.geojson', encoding="utf-8") as a:
    country_json = json.load(a)

# Mean IWI by country
IWI_df = pd.read_csv('data/mean_IWI_by_country.csv')

# Residual data by country
residual_data = pd.read_csv('data/residual_by_country.csv')

selected_map = reactive.Value(None)

# Band 1 polygons are loaded by default; coordinate precision is reduced to aid rendering.
with open('data/simplified_band_1.geojson', encoding="utf-8") as w:
    band_1_data = json.load(w)
band_1_data = reduce_coordinate_precision(band_1_data, precision=5)
IWI_values = [feature['properties']['IWI'] for feature in band_1_data['features']]

# IWI bin edges and the legend labels / colours that correspond to them.
iwi_bins = [0.052, 0.140, 0.161, 0.187, 0.240, 0.696]  # Custom bin values
iwi_labels = [
    "0.052 – 0.140",
    "0.140 – 0.161",
    "0.161 – 0.187",
    "0.187 – 0.240",
    "0.240 – 0.696",
]
iwi_mappings = {
    "0.052 – 0.140": "#0d0887",
    "0.140 – 0.161": "#7d03a8",
    "0.161 – 0.187": "#cb4679",
    "0.187 – 0.240": "#f89441",
    "0.240 – 0.696": "#f0f921",
}

# Colormap and norm based on the custom bins. `matplotlib.cm.get_cmap` was removed
# in Matplotlib 3.9; the `matplotlib.colormaps` registry is its replacement.
colormap = matplotlib.colormaps["plasma"]
norm = BoundaryNorm(iwi_bins, colormap.N)
# Function to get color based on IWI value
def get_color(iwi):
    """Return the hex colour string for an IWI value.

    The value is passed through the module-level ``norm`` (a ``BoundaryNorm``
    built from ``iwi_bins``) and the result is looked up in the module-level
    ``colormap``; the RGBA tuple is then converted to a hex string.
    """
    rgba = colormap(norm(iwi))
    return colors.to_hex(rgba)
# Function to clean up out-of-range values and get values
def get_clean_values(src, band_idx=1):
    """Read one band from ``src`` and blank out values outside ``(0, 1]``.

    Args:
        src: Object exposing ``read(band_idx)`` that returns a NumPy array
            (the module passes an open rasterio dataset).
        band_idx: Band index forwarded to ``src.read``; defaults to 1.

    Returns:
        The band as a floating-point array in which every value ``<= 0`` or
        ``> 1`` has been replaced by NaN.
    """
    band_data = src.read(band_idx)
    # NaN cannot be stored in an integer array; promote non-float bands first.
    if not np.issubdtype(band_data.dtype, np.floating):
        band_data = band_data.astype(np.float64)
    band_data[(band_data <= 0) | (band_data > 1)] = np.nan
    return band_data
# Read every band exactly once and derive both the pooled valid values (for the
# quantile breaks) and the per-band mean IWI (for the "Trends Over Time" chart).
band_means = []
valid_chunks = []
for i in range(1, wealth_stack.count + 1):
    vals = get_clean_values(wealth_stack, i).ravel()
    valid_chunks.append(vals[~np.isnan(vals)])
    band_means.append(np.nanmean(vals))
# Concatenate once instead of growing a Python list of individual scalars.
all_vals = np.concatenate(valid_chunks) if valid_chunks else np.array([])
del valid_chunks
q_breaks_legend = np.quantile(all_vals, np.linspace(0, 1, 6))
q_breaks = np.quantile(all_vals, np.linspace(0, 1, 11))

# Raster bounds as [[bottom, left], [top, right]] for positioning on the map
bounds = [[wealth_stack.bounds.bottom, wealth_stack.bounds.left],
          [wealth_stack.bounds.top, wealth_stack.bounds.right]]

# Load improvement data (change in IWI by state/province)
# In real app, adjust path
improvement_data = pd.read_csv("data/poverty_improvement_by_state.csv")
# ------------------------------
# 2. UI
# ------------------------------
# Custom CSS for OCR A Std font and other styling.
# The @import rule must stay first in the stylesheet to take effect.
css = """
@import url('https://fonts.cdnfonts.com/css/ocr-a-std');
body {
    font-family: 'OCR A Std', monospace !important;
}
.slider-animate-button {
    background-color: #ffffff !important;
    color: #000000 !important;
    border: 2px solid #000000 !important;
    border-radius: 5px !important;
    padding: 5px 10px !important;
    top: 10px !important;
}
.value-box {
    margin-bottom: 15px;
    padding: 15px;
    border-radius: 5px;
    color: white;
}
.green-box {
    background-color: #00a65a;
}
.blue-box {
    background-color: #0073b7;
}
.red-box {
    background-color: #dd4b39;
}
.share-button {
    display: inline-flex;
    align-items: center;
    justify-content: center;
    gap: 8px;
    padding: 5px 10px;
    font-size: 16px;
    font-weight: normal;
    color: #000;
    background-color: #fff;
    border: 1px solid #ddd;
    border-radius: 6px;
    cursor: pointer;
    box-shadow: 0 1.5px 0 #000;
}
.title-text {
    font-family: 'OCR A Std', monospace;
    font-size: 18px;
}
.subtitle-text {
    font-family: 'OCR A Std', monospace;
    font-size: 14px;
}
#improvement_table .shiny-data-grid {
    width: 100% !important;
}
.nav-link {
    color: white !important;
}
"""
# Share button HTML plus the inline script that drives it. The script tries the
# Web Share API first, then the async Clipboard API, then a hidden <textarea>
# with document.execCommand('copy') for older browsers.
share_button_html = """
<button id="share-button" class="share-button">
  <svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor"
       stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
    <circle cx="18" cy="5" r="3"></circle>
    <circle cx="6" cy="12" r="3"></circle>
    <circle cx="18" cy="19" r="3"></circle>
    <line x1="8.59" y1="13.51" x2="15.42" y2="17.49"></line>
    <line x1="15.41" y1="6.51" x2="8.59" y2="10.49"></line>
  </svg>
  <strong>Share</strong>
</button>
<script>
(function() {
  const shareBtn = document.getElementById('share-button');
  // Nothing to wire up if the button is not in the DOM
  if (!shareBtn) { return; }
  // Reusable helper function to show a small "Copied!" message
  function showCopyNotification() {
    const notification = document.createElement('div');
    notification.innerText = 'Copied to clipboard';
    notification.style.position = 'fixed';
    notification.style.bottom = '20px';
    notification.style.right = '20px';
    notification.style.backgroundColor = 'rgba(0, 0, 0, 0.8)';
    notification.style.color = '#fff';
    notification.style.padding = '8px 12px';
    notification.style.borderRadius = '4px';
    notification.style.zIndex = '9999';
    document.body.appendChild(notification);
    setTimeout(() => { notification.remove(); }, 2000);
  }
  shareBtn.addEventListener('click', function() {
    const currentURL = window.location.href;
    const pageTitle = document.title || 'Check this out!';
    // If browser supports Web Share API
    if (navigator.share) {
      navigator.share({
        title: pageTitle,
        text: '',
        url: currentURL
      })
      .catch((error) => {
        console.log('Sharing failed', error);
      });
    } else {
      // Fallback: Copy URL
      if (navigator.clipboard && navigator.clipboard.writeText) {
        navigator.clipboard.writeText(currentURL).then(() => {
          showCopyNotification();
        }, (err) => {
          console.error('Could not copy text: ', err);
        });
      } else {
        // Double fallback for older browsers
        const textArea = document.createElement('textarea');
        textArea.value = currentURL;
        document.body.appendChild(textArea);
        textArea.select();
        try {
          document.execCommand('copy');
          showCopyNotification();
        } catch (err) {
          alert('Please copy this link:\\n' + currentURL);
        }
        document.body.removeChild(textArea);
      }
    }
  });
})();
</script>
"""
# Create the app UI with dashboard layout: a navbar with three panels
# ("Wealth Map", "Improvement Data", "Trends Over Time"). Output/input ids used
# here must match the function names and `input.<id>()` calls in `server`.
app_ui = ui.page_fluid(
    ui.head_content(
        ui.tags.style(css)
    ),
    ui.page_navbar(
        ui.nav_panel(
            "Wealth Map",
            ui.layout_sidebar(
                ui.sidebar(
                    ui.h4("Map Controls"),
                    ui.input_switch(
                        "SelectedMap", "Enable Country View", False),
                    ui.input_slider(
                        "time_index",
                        "Select Time Period (Years):",
                        min=1,
                        max=len(time_periods),
                        value=1,
                        step=1,
                        animate=True,
                    ),
                    ui.strong("Currently Selected: "),
                    ui.output_text("current_year_range", inline=True),
                    ui.input_select(
                        "color_palette",
                        "Select Color Palette:",
                        {
                            "blue": "blue",
                            "red": "red",
                            "orange": "orange",
                            "purple": "purple",
                            "Spectral": "Spectral (Brewer)",
                        },
                        selected="red",
                    ),
                    ui.input_slider(
                        "opacity",
                        "Map Opacity:",
                        min=0.2,
                        max=1,
                        value=0.8,
                        step=0.1,
                    ),
                    ui.accordion(
                        ui.accordion_panel(
                            'How it works',
                            # Adjacent literals are concatenated verbatim, so each
                            # fragment carries its own trailing space where needed.
                            # NOTE(review): the glyphs leading each <li> look like
                            # mis-decoded emoji; restore them from the upstream file.
                            ui.HTML(
                                "<p>These wealth-index predictions are AI-generated by a "
                                "sequence-aware neural network trained on 30 years of "
                                "<em>Demographic and Health Surveys (DHS)</em> ground-truth data.</p>"
                                "<ul><li>π 57,100+ geo-referenced survey points from DHS</li> "
                                "<li>βοΈ Multi-spectral satellite bands & raster-to-vector feature extraction</li>"
                                "<li>π― Calibrated & validated with held-out DHS clusters (1990–2019)</li></ul>"
                            ),
                        ),
                        id="map_instructions",
                        open=False,
                        multiple=False,
                    ),
                    ui.HTML(share_button_html),
                ),
                ui.layout_column_wrap(
                    ui.value_box(
                        "Highest IWI",
                        ui.output_text("highest_iwi"),
                        showcase=ui.tags.i(class_="fa fa-arrow-up"),
                        theme="success",
                    ),
                    ui.value_box(
                        "Lowest IWI",
                        ui.output_text("lowest_iwi"),
                        showcase=ui.tags.i(class_="fa fa-arrow-down"),
                        theme="danger",
                    ),
                    ui.value_box(
                        "Average IWI",
                        ui.output_text("avg_iwi"),
                        showcase=ui.tags.i(class_="fa fa-balance-scale"),
                        theme="primary",
                    ),
                    width=1/3,
                ),
                ui.layout_column_wrap(
                    ui.card(
                        ui.card_header(
                            ui.h3("Wealth Map of Africa", class_="title-text")),
                        output_widget("country_map"),
                        ui.p(
                            "Click anywhere on the map to view the time-series of IWI for that specific location (shown below)."),
                    ),
                    ui.card(
                        ui.card_header(
                            ui.h3("Time Series at Clicked Location", class_="subtitle-text")),
                        ui.output_plot("clicked_ts_plot"),
                        ui.p(
                            "Click on the map to see the full IWI time-series (1990–2019) for that location."),
                        ui.download_button(
                            "download_country_data", "Download CSV", icon="download"),
                    ),
                ),
                ui.card(
                    ui.card_header(ui.h3(
                        "Ground Truth vs. Prediction Residual Distribution (Selected Country)", class_="subtitle-text")),
                    ui.output_plot("iwi_residuals"),
                    ui.p(
                        "This chart shows the distribution of residuals between ground truth and predicted IWI values based on the selected country."),
                    ui.strong(
                        "Note: wealth estimates for areas without human settlements have been excluded from the analysis."),
                    ui.HTML(
                        "<a href='https://doi.org/10.24963/ijcai.2023/684' target='_blank'>[Paper PDF]</a>"),
                ),
            ),
        ),
        ui.nav_panel(
            "Improvement Data",
            ui.layout_columns(
                ui.card(
                    ui.card_header(
                        ui.h3("Poverty Improvement by State", class_="title-text")),
                    ui.p("This table shows the estimated improvement in mean IWI between 1990–1992 and 2017–2019 for each province in Africa. "
                         "The 'Improvement' column indicates the change in IWI over this period. You can sort or filter the table, "
                         "and use the download button to export the data."),
                    ui.download_button(
                        "download_data", "Download CSV", icon="download"),
                    ui.card(ui.output_data_frame(
                        "improvement_table")),
                )
            ),
        ),
        ui.nav_panel(
            "Trends Over Time",
            ui.card(
                ui.card_header(
                    ui.h3("Average Wealth Index Across Africa Over Time", class_="title-text")),
                ui.p("This chart aggregates the mean IWI across all of Africa in each of the ten time periods. "
                     "It provides a high-level view of how wealth (as measured by IWI) has changed over time."),
                ui.output_plot("trend_plot"),
            ),
        ),
        title=ui.HTML(
            "<span style='font-weight: 600; font-size: 16px;'>"
            "<a href='http://aidevlab.org' target='_blank' "
            "style='font-family: \"OCR A Std\", monospace; color: white; text-decoration: underline;'>"
            "aidevlab.org</a></span>"
        ),
        id="tabs",
        bg="#337ab7",
    ),
)
# ------------------------------
# 3. Server logic
# ------------------------------
def server(input, output, session):
    """Build the map widget and register every output declared in ``app_ui``.

    Args:
        input: Shiny input accessor; this function reads ``input.time_index()``
            and ``input.SelectedMap()``.
        output: Shiny output registry (outputs are registered through the
            ``@render.*`` decorators, so it is not referenced directly).
        session: Shiny session object (not referenced directly).

    The nested functions are named after the output ids used in ``app_ui``;
    without their ``@render`` / ``@reactive`` decorators nothing is registered.
    """
    # Initialize the map widget
    m = Map(center=(0, 20), zoom=3)

    # Attach a per-feature style derived from the feature's IWI value. This writes
    # into the module-level ``band_1_data``; repeating it yields the same values.
    for feature in band_1_data["features"]:
        feature_color = get_color(feature["properties"]["IWI"])
        feature["properties"]["style"] = {
            "color": feature_color,
            "fillColor": feature_color,  # Fill color based on IWI
            "fillOpacity": 0.7,
            "weight": 1,
        }
    band_1_json = GeoJSON(
        data=band_1_data,
        style={'radius': 0.05, 'opacity': 0.8, 'weight': 0.5},
        point_style={'radius': 0.05},
        name='Release',
    )
    legend = LegendControl(
        iwi_mappings,
        position="bottomleft",
        title="IWI Values",
    )
    # Add the legend and the default (band 1) layer to the map
    m.add_control(legend)
    m.add_layer(band_1_json)

    # Country outline layer, shown when "Enable Country View" is switched on
    geo_json = GeoJSON(
        data=country_json,
        style={
            'opacity': 1, 'dashArray': '9', 'fillOpacity': 0.1, 'weight': 1
        },
        hover_style={
            'color': 'white', 'dashArray': '0', 'fillOpacity': 0.5
        },
    )

    # Register the map widget with Shiny
    register_widget("country_map", m)

    # Name of the country most recently clicked in the country view (None until a click)
    selected_country = reactive.Value(None)

    def _exterior_points(geometry):
        """Return the [lon, lat] vertices of the exterior ring(s) of a Polygon/MultiPolygon."""
        rings = geometry["coordinates"]
        if geometry["type"] == "MultiPolygon":
            return [point for polygon in rings for point in polygon[0]]
        if geometry["type"] == "Polygon":
            return list(rings[0])
        return []

    # Get the currently selected raster layer
    @reactive.calc
    def selected_raster():
        return get_clean_values(wealth_stack, input.time_index())

    # Display selected time period
    @render.text
    def current_year_range():
        # The slider is 1-based; the label list is 0-based
        return time_periods[input.time_index() - 1]

    # Swap between the IWI polygon layer and the country layer
    @reactive.effect
    def _():
        if input.SelectedMap():
            to_show, to_hide = geo_json, band_1_json
        else:
            to_show, to_hide = band_1_json, geo_json
        # ipyleaflet raises LayerException when removing a layer that is not on
        # the map or adding one that already is, so check membership first.
        if to_hide in m.layers:
            m.remove_layer(to_hide)
        if to_show not in m.layers:
            m.add_layer(to_show)

    # Handle clicks on the country layer
    def handle_map_click(event=None, feature=None, **kwargs):
        if feature is None:
            return
        selected_country.set(feature['properties']['sovereignt'])
        # Recentre on the mean of the outline vertices (an approximation of the centroid)
        points = _exterior_points(feature['geometry'])
        if points:
            latitudes = [point[1] for point in points]
            longitudes = [point[0] for point in points]
            m.center = (np.mean(latitudes), np.mean(longitudes))
            m.zoom = 5

    geo_json.on_click(handle_map_click)

    # Display value boxes
    @render.text
    def highest_iwi():
        return f"{np.nanmax(selected_raster()):.3f}"

    @render.text
    def lowest_iwi():
        return f"{np.nanmin(selected_raster()):.3f}"

    @render.text
    def avg_iwi():
        return f"{np.nanmean(selected_raster()):.3f}"

    # Generate trend plot for mean IWI across Africa
    @render.plot
    def trend_plot():
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot(range(len(time_periods)), band_means, marker='o',
                color="darkorange", linewidth=2, markersize=6)
        ax.set_xticks(range(len(time_periods)))
        ax.set_xticklabels(time_periods, rotation=45, ha="right")
        ax.set_ylabel("Mean IWI")
        ax.set_ylim(0.1, 0.3)
        ax.set_title("Average IWI Over Time (Africa)")
        ax.grid(True, linestyle='--', alpha=0.7)
        fig.tight_layout()
        return fig

    # Residual distribution for the selected country (drawn by the helper)
    @render.plot
    def iwi_residuals():
        return get_residual_plot(selected_country.get(), residual_data)

    # Plot the IWI time series of the selected country
    @render.plot
    def clicked_ts_plot():
        country_name = selected_country.get()
        fig, ax = plt.subplots(figsize=(10, 4))
        if country_name is None:
            ax.text(0.5, 0.5, "Click on the map to see the IWI time-series here.",
                    horizontalalignment='center', verticalalignment='center',
                    transform=ax.transAxes, fontsize=14)
        elif country_name not in IWI_df.columns:
            # The clicked outline has no matching column in the per-country table
            ax.text(0.5, 0.5, f"No IWI time-series available for {country_name}.",
                    horizontalalignment='center', verticalalignment='center',
                    transform=ax.transAxes, fontsize=14)
        else:
            ax.plot(IWI_df['Band_Number'], IWI_df[country_name],
                    marker='o', color="darkorange", linewidth=2, markersize=6)
            ax.set_xticks(range(1, len(IWI_df['Band_Number']) + 1))
            ax.set_xticklabels(time_periods, rotation=45)
            ax.set_ylabel("IWI (0 to 1)")
            ax.set_ylim(0, 1)
            ax.set_title(f"Time Series of IWI in {country_name}")
            ax.grid(True, linestyle='--', alpha=0.7)
        fig.tight_layout()
        return fig

    # Display improvement data table
    @render.data_frame
    def improvement_table():
        return render.DataGrid(
            improvement_data,
            filters=True,
            height="800px",
        )

    # Download CSV handlers. A download handler must yield the file content;
    # a returned string is interpreted as a path to a file on disk.
    @render.download(filename="poverty_improvement_by_state.csv")
    def download_data():
        yield improvement_data.to_csv(index=False)

    @render.download(filename="country_iwi.csv")
    async def download_country_data():
        country_name = selected_country.get()
        if country_name is None or country_name not in IWI_df.columns:
            # No usable selection: send an empty file instead of raising KeyError
            yield ""
            return
        yield IWI_df[[country_name]].to_csv(index=False)
# ------------------------------
# 4. Create and run the app
# ------------------------------
# Application object combining the UI definition with the server function.
app = App(app_ui, server)