cams-pollution-dashboard / plot_generator.py
aditya-me13's picture
Integration of Aurora
4f0125c
raw
history blame
10.9 kB
# plot_generator.py
# Generate air pollution maps for India using GeoPandas for the map outline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend for web apps
import geopandas as gpd
from pathlib import Path
from datetime import datetime
from constants import INDIA_BOUNDS, COLOR_THEMES
import warnings
warnings.filterwarnings('ignore')
class IndiaMapPlotter:
def __init__(self, plots_dir="plots", shapefile_path="shapefiles/India_State_Boundary.shp"):
"""
Initialize the map plotter
Parameters:
plots_dir (str): Directory to save plots
shapefile_path (str): Path to India boundary shapefile
"""
self.plots_dir = Path(plots_dir)
self.plots_dir.mkdir(exist_ok=True)
try:
self.india_map = gpd.read_file(shapefile_path)
# Ensure it's in lat/lon (WGS84)
if self.india_map.crs is not None and self.india_map.crs.to_epsg() != 4326:
self.india_map = self.india_map.to_crs(epsg=4326)
except Exception as e:
raise FileNotFoundError(f"Could not read the shapefile at '{shapefile_path}'. "
f"Please ensure the file exists. Error: {e}")
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['savefig.bbox'] = 'tight'
plt.rcParams['font.size'] = 10
def create_india_map(self, data_values, metadata, color_theme=None, save_plot=True, custom_title=None):
"""
Create air pollution map over India
"""
try:
# Metadata extraction remains the same
lats = metadata['lats']
lons = metadata['lons']
var_name = metadata['variable_name']
display_name = metadata['display_name']
units = metadata['units']
pressure_level = metadata.get('pressure_level')
time_stamp = metadata.get('timestamp_str')
# Color theme logic remains the same
if color_theme is None:
from constants import AIR_POLLUTION_VARIABLES
color_theme = AIR_POLLUTION_VARIABLES.get(var_name, {}).get('cmap', 'viridis')
if color_theme not in COLOR_THEMES:
print(f"Warning: Color theme '{color_theme}' not found, using 'viridis'")
color_theme = 'viridis'
# Create figure and axes - match interactive plot proportions (1400x1000 = 1.4:1 ratio)
fig = plt.figure(figsize=(16, 10)) # Wider to match interactive plot
ax = fig.add_subplot(1, 1, 1)
# Set map extent
ax.set_xlim(INDIA_BOUNDS['lon_min'], INDIA_BOUNDS['lon_max'])
ax.set_ylim(INDIA_BOUNDS['lat_min'], INDIA_BOUNDS['lat_max'])
# --- KEY CHANGE: PLOT ORDER & ZORDER ---
# 1. Plot the pollution data in the background (lower zorder) - pixel-wise like interactive plots
if lons.ndim == 1 and lats.ndim == 1:
lon_grid, lat_grid = np.meshgrid(lons, lats)
else:
lon_grid, lat_grid = lons, lats
valid_data = data_values[~np.isnan(data_values)]
if len(valid_data) == 0:
raise ValueError("All data values are NaN - cannot create plot")
from constants import AIR_POLLUTION_VARIABLES
vmax_percentile = AIR_POLLUTION_VARIABLES.get(var_name, {}).get('vmax_percentile', 90)
vmin = np.nanpercentile(valid_data, 5)
vmax = np.nanpercentile(valid_data, vmax_percentile)
if vmax <= vmin:
vmax = vmin + 1.0
# Use imshow for pixel-wise display - matches interactive plot orientation
extent = [lons.min(), lons.max(), lats.min(), lats.max()]
# Handle latitude order for proper orientation
# NetCDF files often have descending latitudes, but imshow with origin='lower' expects ascending
lat_ascending = lats[0] < lats[-1] if len(lats) > 1 else True
if lat_ascending:
# Lats are ascending (good for origin='lower')
plot_data = data_values
else:
# Lats are descending, flip to match origin='lower'
plot_data = np.flipud(data_values)
im = ax.imshow(plot_data, cmap=color_theme, vmin=vmin, vmax=vmax,
extent=extent, origin='lower', aspect='auto', # Changed to 'auto' to match interactive plot
interpolation='nearest', zorder=1)
# Auto-adjust bounds if INDIA_BOUNDS is too small or wrong
xmin, ymin, xmax, ymax = self.india_map.total_bounds
if not (INDIA_BOUNDS['lon_min'] <= xmin <= INDIA_BOUNDS['lon_max'] and INDIA_BOUNDS['lon_min'] <= xmax <= INDIA_BOUNDS['lon_max']):
print("⚠️ Warning: Using shapefile's actual bounds instead of INDIA_BOUNDS.")
ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)
# 2. Plot the India map outlines on top of the data (higher zorder)
self.india_map.plot(ax=ax, edgecolor='black', facecolor='none',
linewidth=0.8, zorder=2) # <-- CHANGED: Set zorder=2 (foreground)
# Add colorbar
cbar = plt.colorbar(im, ax=ax, shrink=0.6, pad=0.02, aspect=30)
cbar_label = f"{display_name}" + (f" ({units})" if units else "")
cbar.set_label(cbar_label, fontsize=12, labelpad=15)
# Add gridlines and labels
ax.grid(True, linestyle='--', alpha=0.6, color='gray', zorder=3)
ax.set_xlabel("Longitude", fontsize=10)
ax.set_ylabel("Latitude", fontsize=10)
ax.tick_params(axis='both', which='major', labelsize=10)
# Title creation logic - include pressure level and plot type
if custom_title:
title = custom_title
else:
title = f'{display_name} Concentration over India (Static)'
if pressure_level:
title += f' at {pressure_level} hPa'
title += f' on {time_stamp}'
plt.title(title, fontsize=14, pad=20, weight='bold')
# Statistics and theme info boxes remain the same
stats_text = self._create_stats_text(valid_data, units)
ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
bbox=dict(boxstyle="round,pad=0.5", facecolor="white", alpha=0.9),
verticalalignment='top', fontsize=10, zorder=4)
theme_text = f"Color Theme: {COLOR_THEMES[color_theme]}"
ax.text(0.98, 0.02, theme_text, transform=ax.transAxes,
bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8),
horizontalalignment='right', verticalalignment='bottom', fontsize=9, zorder=4)
plt.tight_layout()
plot_path = None
if save_plot:
plot_path = self._save_plot(fig, var_name, display_name, pressure_level, color_theme, time_stamp)
plt.close(fig)
return plot_path
except Exception as e:
plt.close('all')
raise Exception(f"Error creating map: {str(e)}")
# All other helper methods (_create_stats_text, _save_plot, etc.) are unchanged.
# The `create_comparison_plot` method is also left out for brevity but would need the same zorder fix.
# The full, unchanged code for the helper methods from the previous answer is still valid.
def _create_stats_text(self, data, units):
units_str = f" {units}" if units else ""
stats = {'Min': np.nanmin(data), 'Max': np.nanmax(data), 'Mean': np.nanmean(data), 'Median': np.nanmedian(data), 'Std': np.nanstd(data)}
def format_number(val):
if abs(val) >= 1000: return f"{val:.0f}"
elif abs(val) >= 10: return f"{val:.1f}"
else: return f"{val:.2f}"
stats_lines = [f"{name}: {format_number(val)}{units_str}" for name, val in stats.items()]
return "\n".join(stats_lines)
def _save_plot(self, fig, var_name, display_name, pressure_level, color_theme, time_stamp):
# Handle None values with fallbacks
display_name = display_name or var_name or 'Unknown'
time_stamp = time_stamp or 'Unknown_Time'
safe_display_name = display_name.replace('/', '_').replace(' ', '_').replace('₂', '2').replace('₃', '3').replace('.', '_')
safe_time_stamp = time_stamp.replace('-', '').replace(':', '').replace(' ', '_')
filename_parts = [f"{safe_display_name}_India"]
if pressure_level:
filename_parts.append(f"{int(pressure_level)}hPa")
filename_parts.extend([color_theme, safe_time_stamp])
filename = "_".join(filename_parts) + ".png"
plot_path = self.plots_dir / filename
fig.savefig(plot_path, dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none')
print(f"Plot saved: {plot_path}")
return str(plot_path)
def list_available_themes(self):
return COLOR_THEMES
def test_plot_generator():
print("Testing plot generator with GeoPandas and zorder fix...")
lats, lons = np.linspace(6, 38, 50), np.linspace(68, 98, 60)
lon_grid, lat_grid = np.meshgrid(lons, lats)
data = np.sin(lat_grid * 0.1) * np.cos(lon_grid * 0.1) * 100 + 50
data += np.random.normal(0, 10, data.shape)
metadata = {
'variable_name': 'pm25', 'display_name': 'PM2.5', 'units': 'µg/m³',
'lats': lats, 'lons': lons, 'pressure_level': None,
'timestamp_str': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
}
shapefile_path = "shapefiles/India_State_Boundary.shp"
if not Path(shapefile_path).exists():
print(f"❌ Test failed: Shapefile not found at '{shapefile_path}'.")
print("Please make sure you have unzipped 'India_State_Boundary.zip' into a 'shapefiles' folder.")
return False
plotter = IndiaMapPlotter(shapefile_path=shapefile_path)
try:
plot_path = plotter.create_india_map(data, metadata, color_theme='YlOrRd')
print(f"✅ Test plot created successfully: {plot_path}")
return True
except Exception as e:
print(f"❌ Test failed: {str(e)}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
test_plot_generator()