Spaces:
Sleeping
Sleeping
| # interactive_plot_generator.py | |
| # Generate interactive air pollution maps for India with hover information | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| import geopandas as gpd | |
| from pathlib import Path | |
| from datetime import datetime | |
| from constants import INDIA_BOUNDS, COLOR_THEMES | |
| import plotly.io as pio | |
| import warnings | |
# NOTE(review): this silences *every* warning for the whole process, including
# in any module that imports this file. Consider scoping it with
# warnings.catch_warnings() around the specific calls that need it.
warnings.filterwarnings('ignore')
class InteractiveIndiaMapPlotter:
    """Build interactive (Plotly) air-pollution maps over India.

    State boundaries are read once from a shapefile at construction time and
    overlaid on every map produced by :meth:`create_india_map`.
    """

    # Size (pixels) used for the on-screen figure, the mode-bar PNG download
    # and the static PNG export, so all three stay consistent.
    FIGURE_WIDTH = 1400
    FIGURE_HEIGHT = 1000

    # Matplotlib colormap name -> Plotly colorscale name. This is intended to
    # cover every key in constants.COLOR_THEMES; names missing here fall back
    # to 'Viridis' in create_india_map.
    COLORMAP_MAPPING = {
        # Sequential color schemes
        'viridis': 'Viridis',
        'plasma': 'Plasma',
        'inferno': 'Inferno',
        'magma': 'Magma',
        'cividis': 'Cividis',
        # Single-hue sequential schemes
        'YlOrRd': 'YlOrRd',
        'Oranges': 'Oranges',
        'Reds': 'Reds',
        'Purples': 'Purples',
        'Blues': 'Blues',
        'Greens': 'Greens',
        # Diverging schemes
        'coolwarm': 'RdBu_r',
        'RdYlBu': 'RdYlBu',
        'Spectral': 'Spectral',
        'Spectral_r': 'Spectral_r',
        'RdYlGn_r': 'RdYlGn_r',
        # Other schemes
        'jet': 'Jet',
        'turbo': 'Turbo',
    }

    # Character substitutions that make display names / timestamps usable in
    # file names (e.g. 'PM2.5' -> 'PM2_5', 'NO₂' -> 'NO2').
    _NAME_TRANSLATION = str.maketrans({'/': '_', ' ': '_', '₂': '2', '₃': '3', '.': '_'})
    _TIME_TRANSLATION = str.maketrans({'-': None, ':': None, ' ': '_'})

    def __init__(self, plots_dir="plots", shapefile_path="shapefiles/India_State_Boundary.shp"):
        """
        Initialize the interactive map plotter.

        Parameters:
            plots_dir (str | Path): Directory to save plots; created (including
                missing parent directories) if it does not exist.
            shapefile_path (str): Path to the India state-boundary shapefile.

        Raises:
            FileNotFoundError: If the shapefile cannot be read or reprojected
                (the underlying error is chained as the cause).
        """
        self.plots_dir = Path(plots_dir)
        self.plots_dir.mkdir(parents=True, exist_ok=True)
        try:
            self.india_map = gpd.read_file(shapefile_path)
            # Boundaries are drawn on longitude/latitude axes, so reproject to
            # WGS84 when the file declares a different CRS.
            if self.india_map.crs is not None and self.india_map.crs.to_epsg() != 4326:
                self.india_map = self.india_map.to_crs(epsg=4326)
        except Exception as e:
            raise FileNotFoundError(
                f"Could not read the shapefile at '{shapefile_path}'. "
                f"Please ensure the file exists. Error: {e}"
            ) from e

    def create_india_map(self, data_values, metadata, color_theme=None, save_plot=True, custom_title=None):
        """
        Create an interactive air pollution map over India with hover information.

        Parameters:
            data_values (array-like): 2D grid of pollution values. When
                metadata 'lats'/'lons' are 1D, its shape must be
                (len(lats), len(lons)). NaN or masked cells are treated as missing.
            metadata (dict): Must contain 'lats', 'lons', 'variable_name',
                'display_name' and 'units'; 'pressure_level' and
                'timestamp_str' are optional.
            color_theme (str): Color theme name from COLOR_THEMES. Defaults to
                the variable's 'cmap' entry in AIR_POLLUTION_VARIABLES, else 'viridis'.
            save_plot (bool): Whether to save the plot as HTML and PNG.
            custom_title (str): Custom title for the plot.

        Returns:
            dict: Dictionary containing paths to saved files and HTML content
                - 'html_path': Path to interactive HTML file (None if not saved)
                - 'png_path': Path to static PNG file (None if not saved or if
                  PNG export failed)
                - 'html_content': HTML <div> snippet for embedding (always set)

        Raises:
            Exception: Wraps (and chains) any error raised while building or
                saving the map, including a ValueError when every value is NaN.
        """
        try:
            lats = np.asarray(metadata['lats'])
            lons = np.asarray(metadata['lons'])
            var_name = metadata['variable_name']
            display_name = metadata['display_name']
            units = metadata['units']
            pressure_level = metadata.get('pressure_level')
            time_stamp = metadata.get('timestamp_str')

            # Normalize to a float ndarray; masked cells become NaN so they are
            # excluded from the statistics exactly like NaN cells.
            data_values = np.ma.filled(np.ma.asarray(data_values, dtype=float), np.nan)

            from constants import AIR_POLLUTION_VARIABLES
            variable_config = AIR_POLLUTION_VARIABLES.get(var_name, {})

            if color_theme is None:
                color_theme = variable_config.get('cmap', 'viridis')
            plotly_colorscale = self.COLORMAP_MAPPING.get(color_theme, 'Viridis')

            # Hover labels need per-cell coordinates.
            if lons.ndim == 1 and lats.ndim == 1:
                lon_grid, lat_grid = np.meshgrid(lons, lats)
            else:
                lon_grid, lat_grid = lons, lats

            valid_data = data_values[~np.isnan(data_values)]
            if valid_data.size == 0:
                raise ValueError("All data values are NaN - cannot create plot")

            # Color range spans the 5th percentile up to a per-variable upper
            # percentile (default 90th) rather than the raw min/max.
            vmax_percentile = variable_config.get('vmax_percentile', 90)
            vmin = np.nanpercentile(valid_data, 5)
            vmax = np.nanpercentile(valid_data, vmax_percentile)
            if vmax <= vmin:
                # Degenerate (e.g. constant) field: keep a non-empty range.
                vmax = vmin + 1.0

            hover_text = self._create_hover_text(lon_grid, lat_grid, data_values, display_name, units)

            fig = go.Figure()
            fig.add_trace(go.Heatmap(
                x=lons,
                y=lats,
                z=data_values,
                colorscale=plotly_colorscale,
                zmin=vmin,
                zmax=vmax,
                hovertext=hover_text,
                hoverinfo='text',
                colorbar=dict(
                    title=dict(
                        text=f"{display_name}" + (f"<br>({units})" if units else ""),
                        side="right"
                    ),
                    thickness=20,
                    len=0.6,
                    x=1.02
                )
            ))
            self._add_boundary_traces(fig)

            title = self._build_title(display_name, pressure_level, time_stamp, custom_title)
            stats_text = self._create_stats_text(valid_data, units)
            theme_name = COLOR_THEMES.get(color_theme, color_theme)
            lon_range, lat_range = self._axis_ranges()

            fig.update_layout(
                title=dict(
                    text=title,
                    x=0.5,
                    xanchor='center',
                    font=dict(size=18, weight='bold')
                ),
                xaxis=dict(
                    title='Longitude',
                    range=lon_range,
                    showgrid=True,
                    gridcolor='rgba(128, 128, 128, 0.3)',
                    zeroline=False
                ),
                yaxis=dict(
                    title='Latitude',
                    range=lat_range,
                    showgrid=True,
                    gridcolor='rgba(128, 128, 128, 0.3)',
                    zeroline=False,
                    # Lock 1 degree latitude == 1 degree longitude on screen.
                    scaleanchor="x",
                    scaleratio=1
                ),
                width=self.FIGURE_WIDTH,
                height=self.FIGURE_HEIGHT,
                plot_bgcolor='white',
                dragmode='zoom',
                showlegend=False,
                hovermode='closest',
                modebar=dict(
                    bgcolor='rgba(255, 255, 255, 0.8)',
                    activecolor='rgb(0, 123, 255)',
                    orientation='h'
                ),
                annotations=self._build_annotations(stats_text, theme_name)
            )

            config = self._build_plot_config()

            # The embeddable snippet is returned whether or not files are saved.
            html_content = pio.to_html(
                fig,
                config=config,
                include_plotlyjs='cdn',
                div_id='interactive-plot',
                full_html=False
            )
            result = {'html_content': html_content, 'html_path': None, 'png_path': None}
            if save_plot:
                result['html_path'] = self._save_html_plot(
                    fig, var_name, display_name, pressure_level, color_theme, time_stamp, config)
                # PNG export is best-effort (None when the image backend fails).
                result['png_path'] = self._save_png_plot(
                    fig, var_name, display_name, pressure_level, color_theme, time_stamp)
            return result
        except Exception as e:
            raise Exception(f"Error creating interactive map: {str(e)}") from e

    @staticmethod
    def _build_title(display_name, pressure_level, time_stamp, custom_title=None):
        """Return the plot title; custom_title wins, missing parts are omitted."""
        if custom_title:
            return custom_title
        title = f'{display_name} Concentration over India (Interactive)'
        if pressure_level:
            title += f' at {pressure_level} hPa'
        if time_stamp:
            title += f' on {time_stamp}'
        return title

    def _axis_ranges(self):
        """Return (lon_range, lat_range) for the plot axes.

        INDIA_BOUNDS is used unless the shapefile's western edge falls outside
        its longitude window, in which case the shapefile's own extent is used.
        """
        xmin, ymin, xmax, ymax = self.india_map.total_bounds
        if INDIA_BOUNDS['lon_min'] <= xmin <= INDIA_BOUNDS['lon_max']:
            return ([INDIA_BOUNDS['lon_min'], INDIA_BOUNDS['lon_max']],
                    [INDIA_BOUNDS['lat_min'], INDIA_BOUNDS['lat_max']])
        return [xmin, xmax], [ymin, ymax]

    @staticmethod
    def _build_annotations(stats_text, theme_name):
        """Return the statistics box, color-theme box and usage-hint annotations."""
        return [
            # Statistics box (top-left)
            dict(
                text=stats_text.replace('\n', '<br>'),
                xref='paper', yref='paper',
                x=0.02, y=0.98,
                xanchor='left', yanchor='top',
                showarrow=False,
                bgcolor='rgba(255, 255, 255, 0.9)',
                bordercolor='black',
                borderwidth=1,
                borderpad=10,
                font=dict(size=11)
            ),
            # Theme info box (bottom-right)
            dict(
                text=f'Color Theme: {theme_name}',
                xref='paper', yref='paper',
                x=0.98, y=0.02,
                xanchor='right', yanchor='bottom',
                showarrow=False,
                bgcolor='rgba(211, 211, 211, 0.8)',
                bordercolor='gray',
                borderwidth=1,
                borderpad=8,
                font=dict(size=10)
            ),
            # Usage instructions (bottom-center)
            dict(
                text='🔍 Zoom: Mouse wheel or zoom tool | 📍 Hover: Show coordinates & values | 📥 Download: Camera icon',
                xref='paper', yref='paper',
                x=0.5, y=0.02,
                xanchor='center', yanchor='bottom',
                showarrow=False,
                bgcolor='rgba(173, 216, 230, 0.8)',
                bordercolor='steelblue',
                borderwidth=1,
                borderpad=8,
                font=dict(size=10, color='darkblue')
            ),
        ]

    def _build_plot_config(self):
        """Return the Plotly config: mode bar, drawing tools and PNG download options."""
        return {
            'displayModeBar': True,
            'displaylogo': False,
            'modeBarButtonsToAdd': [
                'drawline',
                'drawopenpath',
                'drawclosedpath',
                'drawcircle',
                'drawrect',
                'eraseshape'
            ],
            'modeBarButtonsToRemove': ['lasso2d', 'select2d'],
            'toImageButtonOptions': {
                'format': 'png',
                # Download name carries the time the figure was built.
                'filename': f'india_pollution_map_{datetime.now().strftime("%Y%m%d_%H%M%S")}',
                'height': self.FIGURE_HEIGHT,
                'width': self.FIGURE_WIDTH,
                'scale': 2
            },
            'responsive': True
        }

    def _add_boundary_traces(self, fig):
        """Overlay all state boundaries on fig as ONE black line trace.

        Every polygon ring (exterior and interior) is appended to shared x/y
        lists separated by None, which Plotly renders as a gap. This avoids
        one trace per polygon. Missing, empty and non-polygonal geometries
        are skipped.
        """
        xs, ys = [], []
        for geometry in self.india_map.geometry:
            if geometry is None or geometry.is_empty:
                continue
            if geometry.geom_type == 'Polygon':
                polygons = [geometry]
            elif geometry.geom_type == 'MultiPolygon':
                polygons = list(geometry.geoms)
            else:
                continue
            for polygon in polygons:
                for ring in (polygon.exterior, *polygon.interiors):
                    ring_x, ring_y = ring.xy
                    xs.extend(ring_x)
                    xs.append(None)
                    ys.extend(ring_y)
                    ys.append(None)
        if xs:
            fig.add_trace(go.Scatter(
                x=xs,
                y=ys,
                mode='lines',
                line=dict(color='black', width=1),
                connectgaps=False,
                hoverinfo='skip',
                showlegend=False
            ))

    def _add_polygon_trace(self, fig, polygon):
        """Add one polygon's exterior ring to fig as its own line trace.

        Kept for callers that draw individual polygons; create_india_map uses
        the batched _add_boundary_traces instead.
        """
        x, y = polygon.exterior.xy
        fig.add_trace(go.Scatter(
            x=list(x),
            y=list(y),
            mode='lines',
            line=dict(color='black', width=1),
            hoverinfo='skip',
            showlegend=False
        ))

    @staticmethod
    def _format_number(value):
        """Format a number with precision that shrinks as its magnitude grows."""
        if abs(value) >= 1000:
            return f"{value:.0f}"
        if abs(value) >= 10:
            return f"{value:.1f}"
        return f"{value:.2f}"

    def _create_hover_text(self, lon_grid, lat_grid, data_values, display_name, units):
        """Return an object array (shape of data_values) of HTML hover labels.

        lat_grid/lon_grid are either 2D (same shape as data_values) or 1D, in
        which case latitude is indexed by row and longitude by column.
        """
        hover_text = np.empty(data_values.shape, dtype=object)
        units_str = f" {units}" if units else ""
        for i, j in np.ndindex(data_values.shape):
            lat = lat_grid[i, j] if lat_grid.ndim == 2 else lat_grid[i]
            lon = lon_grid[i, j] if lon_grid.ndim == 2 else lon_grid[j]
            value = data_values[i, j]
            value_str = "N/A" if np.isnan(value) else f"{self._format_number(value)}{units_str}"
            hover_text[i, j] = (
                f"<b>{display_name}</b>: {value_str}<br>"
                f"<b>Latitude</b>: {lat:.3f}°<br>"
                f"<b>Longitude</b>: {lon:.3f}°"
            )
        return hover_text

    def _create_stats_text(self, data, units):
        """Return newline-separated Min/Max/Mean/Median/Std lines for the annotation."""
        units_str = f" {units}" if units else ""
        stats = {
            'Min': np.nanmin(data),
            'Max': np.nanmax(data),
            'Mean': np.nanmean(data),
            'Median': np.nanmedian(data),
            'Std': np.nanstd(data)
        }
        return "\n".join(
            f"{name}: {self._format_number(val)}{units_str}" for name, val in stats.items()
        )

    def _build_plot_path(self, kind, extension, var_name, display_name, pressure_level, color_theme, time_stamp):
        """Return plots_dir / '<name>_India_<kind>[_<P>hPa]_<theme>_<time><extension>'.

        A missing display_name falls back to var_name, then 'Unknown'; a
        missing time_stamp falls back to 'Unknown_Time'.
        """
        display_name = display_name or var_name or 'Unknown'
        time_stamp = time_stamp or 'Unknown_Time'
        safe_display_name = str(display_name).translate(self._NAME_TRANSLATION)
        safe_time_stamp = str(time_stamp).translate(self._TIME_TRANSLATION)
        filename_parts = [f"{safe_display_name}_India_{kind}"]
        if pressure_level:
            filename_parts.append(f"{int(pressure_level)}hPa")
        filename_parts.extend([str(color_theme), safe_time_stamp])
        return self.plots_dir / ("_".join(filename_parts) + extension)

    def _save_html_plot(self, fig, var_name, display_name, pressure_level, color_theme, time_stamp, config):
        """Save the interactive plot as a standalone HTML file and return its path."""
        plot_path = self._build_plot_path(
            'interactive', '.html', var_name, display_name, pressure_level, color_theme, time_stamp)
        fig.write_html(str(plot_path), config=config, include_plotlyjs='cdn')
        print(f"Interactive HTML plot saved: {plot_path}")
        return str(plot_path)

    def _save_png_plot(self, fig, var_name, display_name, pressure_level, color_theme, time_stamp):
        """Save the plot as a static PNG; return its path, or None if export fails.

        Export failures (e.g. no working image backend) are reported and
        swallowed so that the HTML output is still usable.
        """
        plot_path = self._build_plot_path(
            'static', '.png', var_name, display_name, pressure_level, color_theme, time_stamp)
        try:
            fig.write_image(str(plot_path), format='png',
                            width=self.FIGURE_WIDTH, height=self.FIGURE_HEIGHT, scale=2)
            print(f"Static PNG plot saved: {plot_path}")
            return str(plot_path)
        except Exception as e:
            print(f"Warning: Could not save PNG: {e}")
            return None

    def list_available_themes(self):
        """Return the COLOR_THEMES mapping from constants."""
        return COLOR_THEMES
def test_interactive_plot_generator():
    """Smoke-test InteractiveIndiaMapPlotter end to end on synthetic data.

    Builds a smooth PM2.5-like field on a grid roughly covering India, adds
    seeded (reproducible) noise, renders it with the 'YlOrRd' theme and
    reports the saved file paths.

    Returns:
        bool: True if the map was created; False if the shapefile is missing
        or if loading it or plotting raised.
    """
    print("Testing interactive plot generator...")
    shapefile_path = "shapefiles/India_State_Boundary.shp"
    if not Path(shapefile_path).exists():
        print(f"❌ Test failed: Shapefile not found at '{shapefile_path}'.")
        print("Please make sure you have unzipped 'India_State_Boundary.zip' into a 'shapefiles' folder.")
        return False

    # Synthetic data on a regular lat/lon grid; a fixed seed keeps the noise
    # (and therefore the rendered map) identical from run to run.
    lats = np.linspace(6, 38, 50)
    lons = np.linspace(68, 98, 60)
    lon_grid, lat_grid = np.meshgrid(lons, lats)
    rng = np.random.default_rng(seed=0)
    data = np.sin(lat_grid * 0.1) * np.cos(lon_grid * 0.1) * 100 + 50
    data += rng.normal(0, 10, data.shape)
    metadata = {
        'variable_name': 'pm25',
        'display_name': 'PM2.5',
        'units': 'µg/m³',
        'lats': lats,
        'lons': lons,
        'pressure_level': None,
        'timestamp_str': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    }

    try:
        # Constructed inside the try so an unreadable shapefile (raised as
        # FileNotFoundError) is reported as a failed test instead of escaping.
        plotter = InteractiveIndiaMapPlotter(shapefile_path=shapefile_path)
        result = plotter.create_india_map(data, metadata, color_theme='YlOrRd')
    except Exception as e:
        print(f"❌ Test failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return False

    if result.get('html_path'):
        print(f"✅ Test interactive HTML plot created successfully: {result['html_path']}")
    if result.get('png_path'):
        print(f"✅ Test static PNG plot created successfully: {result['png_path']}")
    return True
def test_color_themes():
    """Check that every key in constants.COLOR_THEMES has a Plotly colorscale.

    Prints one status row per theme. Missing themes are then listed in
    COLOR_THEMES order, so the report is stable from run to run.

    NOTE(review): the mapping below duplicates the colormap mapping defined in
    InteractiveIndiaMapPlotter; keep the two in sync.

    Returns:
        bool: True if every theme is mapped, False otherwise.
    """
    from constants import COLOR_THEMES
    colormap_mapping = {
        # Sequential color schemes
        'viridis': 'Viridis',
        'plasma': 'Plasma',
        'inferno': 'Inferno',
        'magma': 'Magma',
        'cividis': 'Cividis',
        # Single-hue sequential schemes
        'YlOrRd': 'YlOrRd',
        'Oranges': 'Oranges',
        'Reds': 'Reds',
        'Purples': 'Purples',
        'Blues': 'Blues',
        'Greens': 'Greens',
        # Diverging schemes
        'coolwarm': 'RdBu_r',
        'RdYlBu': 'RdYlBu',
        'Spectral': 'Spectral',
        'Spectral_r': 'Spectral_r',
        'RdYlGn_r': 'RdYlGn_r',
        # Other schemes
        'jet': 'Jet',
        'turbo': 'Turbo'
    }
    print("🎨 Testing color theme mappings:")
    print(f"{'Color Theme':<15} {'Plotly Colorscale':<20} {'Status'}")
    print("-" * 50)

    # A list (not a set) keeps missing themes in COLOR_THEMES order, so the
    # summary message is identical across runs.
    missing_themes = []
    for theme_key in COLOR_THEMES:
        if theme_key in colormap_mapping:
            plotly_scale = colormap_mapping[theme_key]
            status = "✅ Mapped"
        else:
            plotly_scale = "Viridis (default)"
            status = "⚠️ Missing"
            missing_themes.append(theme_key)
        print(f"{theme_key:<15} {plotly_scale:<20} {status}")

    if missing_themes:
        print(f"\n❌ Missing mappings for: {', '.join(missing_themes)}")
        return False
    print(f"\n✅ All {len(COLOR_THEMES)} color themes are properly mapped!")
    return True
if __name__ == "__main__":
    # Report the smoke-test result through the exit status (0 = pass,
    # 1 = fail) so scripts and CI can detect failures.
    raise SystemExit(0 if test_interactive_plot_generator() else 1)