# interactive_plot_generator.py
# Generate interactive air pollution maps for India with hover information

import warnings
from datetime import datetime
from pathlib import Path

import geopandas as gpd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

from constants import COLOR_THEMES, INDIA_BOUNDS

warnings.filterwarnings('ignore')


# Matplotlib colormap names (the keys of COLOR_THEMES and the per-variable
# 'cmap' entries in constants.AIR_POLLUTION_VARIABLES) mapped to Plotly
# colorscale names. Shared by the plotter and test_color_themes() so the two
# can never drift apart.
PLOTLY_COLORSCALES = {
    # Sequential color schemes
    'viridis': 'Viridis',
    'plasma': 'Plasma',
    'inferno': 'Inferno',
    'magma': 'Magma',
    'cividis': 'Cividis',
    # Single-hue sequential schemes
    'YlOrRd': 'YlOrRd',
    'Oranges': 'Oranges',
    'Reds': 'Reds',
    'Purples': 'Purples',
    'Blues': 'Blues',
    'Greens': 'Greens',
    # Diverging schemes
    'coolwarm': 'RdBu_r',
    'RdYlBu': 'RdYlBu',
    'Spectral': 'Spectral',
    'Spectral_r': 'Spectral_r',
    'RdYlGn_r': 'RdYlGn_r',
    # Other schemes
    'jet': 'Jet',
    'turbo': 'Turbo',
}

# Colorscale used when a theme has no entry in PLOTLY_COLORSCALES.
DEFAULT_PLOTLY_COLORSCALE = 'Viridis'

# Display name -> filename fragment: path/space/dot characters become '_' and
# Unicode subscript digits (as in "NO₂", "O₃") become plain ASCII digits.
_NAME_TRANSLATION = str.maketrans('₀₁₂₃₄₅₆₇₈₉/ .', '0123456789___')

# Timestamp -> filename fragment: drop '-' and ':', turn spaces into '_'.
_TIMESTAMP_TRANSLATION = str.maketrans({'-': None, ':': None, ' ': '_'})


def _format_value(value):
    """Format a number with fewer decimals as its magnitude grows.

    |value| >= 1000 -> no decimals, >= 10 -> one decimal, otherwise two.
    """
    if abs(value) >= 1000:
        return f"{value:.0f}"
    if abs(value) >= 10:
        return f"{value:.1f}"
    return f"{value:.2f}"


class InteractiveIndiaMapPlotter:
    """Render gridded air-pollution data as interactive Plotly maps of India."""

    def __init__(self, plots_dir="plots", shapefile_path="shapefiles/India_State_Boundary.shp"):
        """
        Initialize the interactive map plotter.

        Parameters:
            plots_dir (str): Directory to save plots (created if it does not exist)
            shapefile_path (str): Path to the India boundary shapefile

        Raises:
            FileNotFoundError: If the shapefile cannot be read or reprojected.
        """
        self.plots_dir = Path(plots_dir)
        self.plots_dir.mkdir(parents=True, exist_ok=True)

        try:
            self.india_map = gpd.read_file(shapefile_path)
            # Ensure it's in lat/lon (WGS84)
            if self.india_map.crs is not None and self.india_map.crs.to_epsg() != 4326:
                self.india_map = self.india_map.to_crs(epsg=4326)
        except Exception as e:
            raise FileNotFoundError(
                f"Could not read the shapefile at '{shapefile_path}'. "
                f"Please ensure the file exists. Error: {e}"
            ) from e

    def create_india_map(self, data_values, metadata, color_theme=None, save_plot=True, custom_title=None):
        """
        Create an interactive air pollution map over India with hover information.

        Parameters:
            data_values (np.ndarray): 2D array of pollution data; rows follow
                ``metadata['lats']`` and columns follow ``metadata['lons']``.
            metadata (dict): Required keys 'lats', 'lons', 'variable_name',
                'display_name', 'units'; optional 'pressure_level', 'timestamp_str'.
            color_theme (str): Color theme name from COLOR_THEMES. Defaults to the
                variable's 'cmap' in AIR_POLLUTION_VARIABLES, else 'viridis'.
                Themes without a Plotly mapping fall back to Viridis.
            save_plot (bool): Whether to save the plot as HTML and PNG
            custom_title (str): Custom title for the plot

        Returns:
            dict: Dictionary containing paths to saved files and HTML content
                - 'html_content': HTML fragment for embedding (always set)
                - 'html_path': Path to interactive HTML file (None if not saved)
                - 'png_path': Path to static PNG file (None if not saved or
                  if PNG export failed)

        Raises:
            Exception: Wraps any error raised while building or saving the map
                (the original error is chained as ``__cause__``).
        """
        try:
            from constants import AIR_POLLUTION_VARIABLES

            # Extract metadata
            lats = np.asarray(metadata['lats'])
            lons = np.asarray(metadata['lons'])
            var_name = metadata['variable_name']
            display_name = metadata['display_name']
            units = metadata['units']
            pressure_level = metadata.get('pressure_level')
            time_stamp = metadata.get('timestamp_str')
            var_info = AIR_POLLUTION_VARIABLES.get(var_name, {})

            # Determine color theme and its Plotly equivalent
            if color_theme is None:
                color_theme = var_info.get('cmap', 'viridis')
            plotly_colorscale = PLOTLY_COLORSCALES.get(color_theme, DEFAULT_PLOTLY_COLORSCALE)

            # Create mesh grid if needed
            if lons.ndim == 1 and lats.ndim == 1:
                lon_grid, lat_grid = np.meshgrid(lons, lats)
            else:
                lon_grid, lat_grid = lons, lats

            valid_data = data_values[~np.isnan(data_values)]
            if valid_data.size == 0:
                raise ValueError("All data values are NaN - cannot create plot")

            vmin, vmax = self._color_limits(valid_data, var_info.get('vmax_percentile', 90))
            hover_text = self._create_hover_text(lon_grid, lat_grid, data_values, display_name, units)

            fig = go.Figure()
            fig.add_trace(go.Heatmap(
                x=lons,
                y=lats,
                z=data_values,
                colorscale=plotly_colorscale,
                zmin=vmin,
                zmax=vmax,
                hovertext=hover_text,
                hoverinfo='text',
                colorbar=dict(
                    title=dict(
                        # Plotly renders text as HTML, so line breaks must be <br>
                        text=f"{display_name}" + (f"<br>({units})" if units else ""),
                        side="right",
                    ),
                    thickness=20,
                    len=0.6,
                    x=1.02,
                ),
            ))

            self._add_boundaries(fig)

            title = custom_title or self._build_title(display_name, pressure_level, time_stamp)
            stats_text = self._create_stats_text(valid_data, units)
            theme_name = COLOR_THEMES.get(color_theme, color_theme)
            lon_range, lat_range = self._axis_ranges()

            fig.update_layout(
                title=dict(
                    text=title,
                    x=0.5,
                    xanchor='center',
                    font=dict(size=18, weight='bold'),
                ),
                xaxis=dict(
                    title='Longitude',
                    range=lon_range,
                    showgrid=True,
                    gridcolor='rgba(128, 128, 128, 0.3)',
                    zeroline=False,
                ),
                yaxis=dict(
                    title='Latitude',
                    range=lat_range,
                    showgrid=True,
                    gridcolor='rgba(128, 128, 128, 0.3)',
                    zeroline=False,
                    scaleanchor="x",
                    scaleratio=1,  # Simplified to match static plot aspect ratio
                ),
                width=1400,
                height=1000,
                plot_bgcolor='white',
                # Enable zoom, pan and other interactive features
                dragmode='zoom',
                showlegend=False,
                hovermode='closest',
                modebar=dict(
                    bgcolor='rgba(255, 255, 255, 0.8)',
                    activecolor='rgb(0, 123, 255)',
                    orientation='h',
                ),
                annotations=self._build_annotations(stats_text, theme_name),
            )

            config = self._build_config()

            # HTML fragment for embedding is always produced
            html_content = pio.to_html(
                fig,
                config=config,
                include_plotlyjs='cdn',
                div_id='interactive-plot',
                full_html=False,
            )
            result = {'html_content': html_content, 'html_path': None, 'png_path': None}

            if save_plot:
                result['html_path'] = self._save_html_plot(
                    fig, var_name, display_name, pressure_level, color_theme, time_stamp, config
                )
                # PNG is best-effort (needs a working static image export backend)
                result['png_path'] = self._save_png_plot(
                    fig, var_name, display_name, pressure_level, color_theme, time_stamp
                )

            return result

        except Exception as e:
            raise Exception(f"Error creating interactive map: {str(e)}") from e

    @staticmethod
    def _color_limits(valid_data, vmax_percentile=90):
        """Return (vmin, vmax) color limits from percentiles of the data.

        vmin is the 5th percentile and vmax the ``vmax_percentile``-th; if they
        collapse (vmax <= vmin), vmax is set to vmin + 1.0 so the scale has width.
        """
        vmin = np.nanpercentile(valid_data, 5)
        vmax = np.nanpercentile(valid_data, vmax_percentile)
        if vmax <= vmin:
            vmax = vmin + 1.0
        return vmin, vmax

    def _add_boundaries(self, fig):
        """Overlay every Polygon/MultiPolygon outline from the shapefile on *fig*.

        Null geometries and other geometry types are skipped.
        """
        for geometry in self.india_map.geometry:
            if geometry is None:
                continue
            if geometry.geom_type == 'Polygon':
                self._add_polygon_trace(fig, geometry)
            elif geometry.geom_type == 'MultiPolygon':
                for polygon in geometry.geoms:
                    self._add_polygon_trace(fig, polygon)

    @staticmethod
    def _build_title(display_name, pressure_level, time_stamp):
        """Default plot title; pressure level and timestamp are appended only when set."""
        title = f'{display_name} Concentration over India (Interactive)'
        if pressure_level:
            title += f' at {pressure_level} hPa'
        if time_stamp:
            title += f' on {time_stamp}'
        return title

    def _axis_ranges(self):
        """Return ``([lon_min, lon_max], [lat_min, lat_max])`` for the map axes.

        INDIA_BOUNDS is used unless the shapefile's minimum longitude falls
        outside INDIA_BOUNDS' longitude span, in which case the shapefile's own
        total bounds are used instead.
        """
        xmin, ymin, xmax, ymax = self.india_map.total_bounds
        if not (INDIA_BOUNDS['lon_min'] <= xmin <= INDIA_BOUNDS['lon_max']):
            return [xmin, xmax], [ymin, ymax]
        return (
            [INDIA_BOUNDS['lon_min'], INDIA_BOUNDS['lon_max']],
            [INDIA_BOUNDS['lat_min'], INDIA_BOUNDS['lat_max']],
        )

    @staticmethod
    def _build_annotations(stats_text, theme_name):
        """Layout annotations: statistics box, color-theme label and usage hints."""
        return [
            # Statistics box (newlines converted to HTML line breaks for Plotly)
            dict(
                text=stats_text.replace('\n', '<br>'),
                xref='paper', yref='paper',
                x=0.02, y=0.98,
                xanchor='left', yanchor='top',
                showarrow=False,
                bgcolor='rgba(255, 255, 255, 0.9)',
                bordercolor='black',
                borderwidth=1,
                borderpad=10,
                font=dict(size=11),
            ),
            # Theme info box
            dict(
                text=f'Color Theme: {theme_name}',
                xref='paper', yref='paper',
                x=0.98, y=0.02,
                xanchor='right', yanchor='bottom',
                showarrow=False,
                bgcolor='rgba(211, 211, 211, 0.8)',
                bordercolor='gray',
                borderwidth=1,
                borderpad=8,
                font=dict(size=10),
            ),
            # Instructions
            dict(
                text='🔍 Zoom: Mouse wheel or zoom tool | 📍 Hover: Show coordinates & values | 📥 Download: Camera icon',
                xref='paper', yref='paper',
                x=0.5, y=0.02,
                xanchor='center', yanchor='bottom',
                showarrow=False,
                bgcolor='rgba(173, 216, 230, 0.8)',
                bordercolor='steelblue',
                borderwidth=1,
                borderpad=8,
                font=dict(size=10, color='darkblue'),
            ),
        ]

    @staticmethod
    def _build_config():
        """Plotly display config: mode-bar tools and PNG download settings."""
        return {
            'displayModeBar': True,
            'displaylogo': False,
            'modeBarButtonsToAdd': [
                'drawline',
                'drawopenpath',
                'drawclosedpath',
                'drawcircle',
                'drawrect',
                'eraseshape',
            ],
            'modeBarButtonsToRemove': ['lasso2d', 'select2d'],
            'toImageButtonOptions': {
                'format': 'png',
                'filename': f'india_pollution_map_{datetime.now().strftime("%Y%m%d_%H%M%S")}',
                'height': 1000,
                'width': 1400,
                'scale': 2,
            },
            'responsive': True,
        }

    def _add_polygon_trace(self, fig, polygon):
        """Add a polygon's exterior ring to *fig* as a black line (holes are not drawn)."""
        x, y = polygon.exterior.xy
        fig.add_trace(go.Scatter(
            x=list(x),
            y=list(y),
            mode='lines',
            line=dict(color='black', width=1),
            hoverinfo='skip',
            showlegend=False,
        ))

    def _create_hover_text(self, lon_grid, lat_grid, data_values, display_name, units):
        """Build an object array of HTML hover labels, one per cell of *data_values*.

        ``lat_grid``/``lon_grid`` may be 2D (same shape as ``data_values``) or
        1D (indexed by row and by column respectively). NaN cells show "N/A".
        """
        units_str = f" {units}" if units else ""
        hover_text = np.empty(data_values.shape, dtype=object)

        for i, j in np.ndindex(data_values.shape):
            lat = lat_grid[i, j] if lat_grid.ndim == 2 else lat_grid[i]
            lon = lon_grid[i, j] if lon_grid.ndim == 2 else lon_grid[j]
            value = data_values[i, j]
            value_str = "N/A" if np.isnan(value) else f"{_format_value(value)}{units_str}"
            # Plotly hover text is HTML, so <br> separates the lines
            hover_text[i, j] = (
                f"{display_name}: {value_str}<br>"
                f"Latitude: {lat:.3f}°<br>"
                f"Longitude: {lon:.3f}°"
            )
        return hover_text

    def _create_stats_text(self, data, units):
        """Return newline-separated Min/Max/Mean/Median/Std lines for *data*."""
        units_str = f" {units}" if units else ""
        stats = {
            'Min': np.nanmin(data),
            'Max': np.nanmax(data),
            'Mean': np.nanmean(data),
            'Median': np.nanmedian(data),
            'Std': np.nanstd(data),
        }
        return "\n".join(f"{name}: {_format_value(val)}{units_str}" for name, val in stats.items())

    @staticmethod
    def _build_filename(kind, var_name, display_name, pressure_level, color_theme, time_stamp, extension):
        """Build a filesystem-safe plot filename.

        Layout: ``<name>_India_<kind>[_<level>hPa]_<theme>_<timestamp>.<extension>``.
        A missing display name falls back to ``var_name`` and then 'Unknown';
        a missing timestamp becomes 'Unknown_Time'.
        """
        display_name = display_name or var_name or 'Unknown'
        time_stamp = time_stamp or 'Unknown_Time'

        parts = [f"{display_name.translate(_NAME_TRANSLATION)}_India_{kind}"]
        if pressure_level:
            parts.append(f"{int(pressure_level)}hPa")
        parts.extend([color_theme, time_stamp.translate(_TIMESTAMP_TRANSLATION)])
        return "_".join(parts) + f".{extension}"

    def _save_html_plot(self, fig, var_name, display_name, pressure_level, color_theme, time_stamp, config):
        """Save the interactive plot as HTML in ``plots_dir`` and return its path."""
        filename = self._build_filename(
            'interactive', var_name, display_name, pressure_level, color_theme, time_stamp, 'html'
        )
        plot_path = self.plots_dir / filename

        fig.write_html(str(plot_path), config=config, include_plotlyjs='cdn')
        print(f"Interactive HTML plot saved: {plot_path}")
        return str(plot_path)

    def _save_png_plot(self, fig, var_name, display_name, pressure_level, color_theme, time_stamp):
        """Save the plot as a static PNG; return its path, or None if export fails.

        Export failures are reported as a warning rather than raised, so a
        missing image-export backend does not abort the HTML workflow.
        """
        filename = self._build_filename(
            'static', var_name, display_name, pressure_level, color_theme, time_stamp, 'png'
        )
        plot_path = self.plots_dir / filename

        try:
            fig.write_image(str(plot_path), format='png', width=1400, height=1000, scale=2)
        except Exception as e:
            print(f"Warning: Could not save PNG: {e}")
            return None

        print(f"Static PNG plot saved: {plot_path}")
        return str(plot_path)

    def list_available_themes(self):
        """Return the COLOR_THEMES mapping."""
        return COLOR_THEMES


def test_interactive_plot_generator():
    """Smoke-test the plotter on synthetic PM2.5 data spanning India.

    Returns:
        bool: True if the map was produced, False otherwise.
    """
    print("Testing interactive plot generator...")

    # Create test data
    lats = np.linspace(6, 38, 50)
    lons = np.linspace(68, 98, 60)
    lon_grid, lat_grid = np.meshgrid(lons, lats)
    data = np.sin(lat_grid * 0.1) * np.cos(lon_grid * 0.1) * 100 + 50
    data += np.random.normal(0, 10, data.shape)

    metadata = {
        'variable_name': 'pm25',
        'display_name': 'PM2.5',
        'units': 'µg/m³',
        'lats': lats,
        'lons': lons,
        'pressure_level': None,
        'timestamp_str': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    }

    shapefile_path = "shapefiles/India_State_Boundary.shp"
    if not Path(shapefile_path).exists():
        print(f"❌ Test failed: Shapefile not found at '{shapefile_path}'.")
        print("Please make sure you have unzipped 'India_State_Boundary.zip' into a 'shapefiles' folder.")
        return False

    try:
        plotter = InteractiveIndiaMapPlotter(shapefile_path=shapefile_path)
        result = plotter.create_india_map(data, metadata, color_theme='YlOrRd')

        if result.get('html_path'):
            print(f"✅ Test interactive HTML plot created successfully: {result['html_path']}")
        if result.get('png_path'):
            print(f"✅ Test static PNG plot created successfully: {result['png_path']}")
        return True
    except Exception as e:
        print(f"❌ Test failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return False


def test_color_themes():
    """Check that every COLOR_THEMES key has a Plotly colorscale mapping.

    Returns:
        bool: True if all themes are mapped, False otherwise.
    """
    print("🎨 Testing color theme mappings:")
    print(f"{'Color Theme':<15} {'Plotly Colorscale':<20} {'Status'}")
    print("-" * 50)

    for theme_key in COLOR_THEMES:
        if theme_key in PLOTLY_COLORSCALES:
            plotly_scale = PLOTLY_COLORSCALES[theme_key]
            status = "✅ Mapped"
        else:
            plotly_scale = "Viridis (default)"
            status = "⚠️ Missing"
        print(f"{theme_key:<15} {plotly_scale:<20} {status}")

    missing_themes = set(COLOR_THEMES) - set(PLOTLY_COLORSCALES)
    if missing_themes:
        print(f"\n❌ Missing mappings for: {', '.join(sorted(missing_themes))}")
        return False

    print(f"\n✅ All {len(COLOR_THEMES)} color themes are properly mapped!")
    return True


if __name__ == "__main__":
    test_interactive_plot_generator()