Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """WorkingProto_10_9.ipynb | |
| Automatically generated by Colab. | |
| Original file is located at | |
| https://colab.research.google.com/drive/1iolwhZ4aRpITScZ_xn4Tuo6CuTckfLVa | |
| """ | |
| # ============================================================================== | |
| # CELL 1: SETUP AND CONSOLIDATED IMPORTS | |
| # ============================================================================== | |
| import gradio as gr | |
| import os | |
| import json | |
| import uuid | |
| import shutil | |
| import zipfile | |
| import pathlib | |
| import tempfile | |
| import pandas as pd | |
| import PIL.Image | |
| from datetime import datetime | |
| import huggingface_hub | |
| import autogluon.multimodal | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import matplotlib.cm as cm | |
| import matplotlib.colors | |
| import folium | |
| from scipy.stats import gaussian_kde | |
| from datasets import load_dataset | |
| from geopy.geocoders import Nominatim | |
| from geopy.extra.rate_limiter import RateLimiter | |
| # ============================================================================== | |
| # CELL 2: CORE LOGIC FOR TAB 1 (UNCHANGED) | |
| # ============================================================================== | |
| # --- Functions for Data Capture --- | |
def get_current_time():
    """Return the current date and time as an ISO 8601 formatted string."""
    now = datetime.now()
    return now.isoformat()
def handle_time_capture():
    """Capture the current time for display in the UI.

    Returns:
        A ``(status_message, timestamp)`` tuple: a Markdown status line that
        embeds the timestamp, and the raw timestamp string itself.
    """
    captured_at = get_current_time()
    return f"π **Time Captured**: {captured_at}", captured_at
def get_gps_js():
    """Return the browser-side JavaScript used by the "Get GPS" button.

    The returned arrow function asks ``navigator.geolocation`` for the current
    position and, on success, writes latitude, longitude, accuracy and an ISO
    timestamp into the textareas under the ``#lat``, ``#lon``, ``#accuracy``
    and ``#device_ts`` elements. It then dispatches an ``input`` event on each
    textarea after setting its value. Failures (no geolocation support,
    missing input fields, or a position error) are reported with ``alert``.
    """
    return """
() => {
if (!navigator.geolocation) { alert("Geolocation not supported"); return; }
navigator.geolocation.getCurrentPosition(
function(position) {
const latBox = document.querySelector('#lat textarea');
const lonBox = document.querySelector('#lon textarea');
const accuracyBox = document.querySelector('#accuracy textarea');
const timestampBox = document.querySelector('#device_ts textarea');
if (latBox && lonBox && accuracyBox && timestampBox) {
latBox.value = position.coords.latitude.toString();
lonBox.value = position.coords.longitude.toString();
accuracyBox.value = position.coords.accuracy.toString();
timestampBox.value = new Date().toISOString();
latBox.dispatchEvent(new Event('input', { bubbles: true }));
lonBox.dispatchEvent(new Event('input', { bubbles: true }));
accuracyBox.dispatchEvent(new Event('input', { bubbles: true }));
timestampBox.dispatchEvent(new Event('input', { bubbles: true }));
} else { alert("Error: Could not find GPS input fields"); }
},
function(err) { alert("GPS Error: " + err.message); },
{ enableHighAccuracy: true, timeout: 10000 }
);
}
"""
def save_to_dataset(image, lat, lon, accuracy_m, device_ts):
    """Build a preview of the record that would be saved (nothing is persisted).

    Args:
        image: The captured/uploaded photo; only checked for presence.
        lat: Latitude value to include in the preview.
        lon: Longitude value to include in the preview.
        accuracy_m: GPS accuracy value to include in the preview.
        device_ts: Device timestamp to include in the preview.

    Returns:
        A ``(status_message, preview_json)`` tuple. When no image is supplied
        the status is an error message and the preview is an empty string.
    """
    if image is not None:
        record = dict(
            image="image.jpg",
            latitude=lat,
            longitude=lon,
            accuracy_m=accuracy_m,
            device_timestamp=device_ts,
            status="Saving Disabled",
        )
        preview_json = json.dumps(record, indent=2)
        return "β **Test Save Successful!** (No data saved)", preview_json

    return "β **Error**: Please capture or upload a photo first.", ""
# Aliases used as the click handlers for the "Get Current Time" and "Save"
# buttons in field_capture_ui.
placeholder_time_capture = handle_time_capture
placeholder_save_action = save_to_dataset
| # --- Functions for Model Prediction --- | |
# Hugging Face model repository and archive file name passed to
# hf_hub_download in _prepare_predictor_dir.
MODEL_REPO_ID = "ddecosmo/lanternfly_classifier"
ZIP_FILENAME = "autogluon_image_predictor_dir.zip"
# Used by do_predict to rename the predict_proba columns to display labels.
CLASS_LABELS = {0: "Lanternfly", 1: "Other Insect", 2: "No Insect"}
# Local download directory and the sub-directory the archive is extracted into.
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR / "predictor_native"
# The loaded predictor; remains None if loading fails.
PREDICTOR = None
def _prepare_predictor_dir():
    """Download the predictor archive and unpack it into a clean directory.

    The archive is fetched from ``MODEL_REPO_ID`` into ``CACHE_DIR`` (using the
    ``HF_TOKEN`` environment variable when set) and extracted into a freshly
    recreated ``EXTRACT_DIR``.

    Returns:
        The path (as ``str``) to load the predictor from: the single
        sub-directory of ``EXTRACT_DIR`` if the archive unpacked to exactly
        one directory, otherwise ``EXTRACT_DIR`` itself.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)

    archive_path = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        token=os.getenv("HF_TOKEN", None),
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,
    )

    # Always extract into an empty directory so stale files never linger.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(str(EXTRACT_DIR))

    entries = list(EXTRACT_DIR.iterdir())
    if len(entries) == 1 and entries[0].is_dir():
        return str(entries[0])
    return str(EXTRACT_DIR)
# Load the predictor once at import time. Any failure is recorded in
# PREDICTOR_LOAD_STATUS and leaves PREDICTOR as None, which do_predict and
# image_model_ui both check for.
try:
    PREDICTOR_DIR = _prepare_predictor_dir()
    PREDICTOR = autogluon.multimodal.MultiModalPredictor.load(PREDICTOR_DIR)
    PREDICTOR_LOAD_STATUS = "β AutoGluon Predictor loaded successfully."
    print(PREDICTOR_LOAD_STATUS)
except Exception as e:
    # Keep the message so the UI can show why the model is unavailable.
    PREDICTOR_LOAD_STATUS = f"β Failed to load AutoGluon Predictor: {e}"
    print(PREDICTOR_LOAD_STATUS)
    PREDICTOR = None
def do_predict(pil_img: PIL.Image.Image):
    """Classify an image and return per-class probabilities plus a summary.

    Args:
        pil_img: The image to classify, or ``None`` when no image is set.

    Returns:
        A 2-tuple ``(probabilities, summary)`` on every path, matching the two
        output components this function is wired to:
        ``probabilities`` maps a label to a float, and ``summary`` is a
        human-readable string. When the model is not loaded or no image was
        given, a placeholder dict and an explanatory message are returned.
    """
    # Both early exits return exactly two values, like the main path.
    if PREDICTOR is None:
        return {"Error": 1.0}, "Model not loaded."
    if pil_img is None:
        return {"No Image": 1.0}, "No image provided."

    # The predictor is given a file path, so write the image to a temporary
    # directory that is removed as soon as prediction has finished.
    with tempfile.TemporaryDirectory() as tmpdir:
        img_path = pathlib.Path(tmpdir) / "input.png"
        pil_img.save(img_path)
        df = pd.DataFrame({"image": [str(img_path)]})
        proba_df = PREDICTOR.predict_proba(df).rename(columns=CLASS_LABELS)

    row = proba_df.iloc[0]
    # Labels missing from the predictor output are reported as 0.0.
    pretty_dict = {label: float(row.get(label, 0.0)) for label in CLASS_LABELS.values()}
    confidence_info = ", ".join(f"{label}: {prob:.2f}" for label, prob in pretty_dict.items())
    return pretty_dict, confidence_info
| # ============================================================================== | |
| # CELL 3: CORE LOGIC FOR TAB 2 (KDE ANALYSIS) | |
| # ============================================================================== | |
# Latitude/longitude bounding box passed to plot_kde_and_points_for_gradio,
# where it defines the heatmap grid extent and the interactive map bounds.
pittsburgh_lat_min = 40.43950159029883
pittsburgh_lat_max = 40.44787067820301
pittsburgh_lon_min = -79.95054304624013
pittsburgh_lon_max = -79.93588847945053
def load_dataframe_from_huggingface():
    """Load the sightings metadata from the Hugging Face dataset.

    Returns:
        The ``train`` split of ``metadata/entries.jsonl`` from
        ``rlogh/lanternfly-data`` converted to a pandas DataFrame, or ``None``
        if loading fails (the error is printed).
    """
    try:
        print("Loading data directly from Hugging Face dataset...")
        frame = load_dataset(
            "rlogh/lanternfly-data",
            data_files="metadata/entries.jsonl",
            split="train",
        ).to_pandas()
    except Exception as exc:
        print(f"β Error loading data from Hugging Face: {exc}")
        return None

    print("β Data successfully loaded into a DataFrame.")
    return frame
def calculate_kde_from_dataframe(df):
    """Fit a 2-D Gaussian KDE to the latitude/longitude columns of ``df``.

    Rows with a missing latitude or longitude are ignored. The caller's
    DataFrame is left untouched.

    Args:
        df: DataFrame expected to contain ``latitude`` and ``longitude``
            columns.

    Returns:
        A 4-tuple ``(latitudes, longitudes, kde_object, error)``. On success
        ``error`` is ``None``; on failure the first three items are ``None``
        and ``error`` is a message string.
    """
    try:
        if 'latitude' not in df.columns or 'longitude' not in df.columns:
            return None, None, None, "Error: DataFrame must contain 'latitude' and 'longitude' columns."

        # Work on a filtered copy instead of dropping rows in place, so the
        # DataFrame owned by the caller is never modified.
        valid_rows = df.dropna(subset=['latitude', 'longitude'])
        latitudes = valid_rows['latitude'].values
        longitudes = valid_rows['longitude'].values

        # The KDE is fitted on (longitude, latitude) rows, i.e. x then y.
        coordinates = np.vstack([longitudes, latitudes])
        kde_object = gaussian_kde(coordinates)
        return latitudes, longitudes, kde_object, None
    except Exception as e:
        # Deliberate catch-all: every failure is reported through the error
        # slot of the return tuple rather than raised.
        return None, None, None, f"Error calculating KDE from DataFrame: {e}"
| import math | |
def find_hotspot_landmark(original_latitudes, original_longitudes, kde_object):
    """Name the campus landmark closest to the densest sighting.

    The hotspot is the input point with the highest KDE density. Its distance
    to each landmark in a fixed list is compared and the nearest one is
    reported.

    Args:
        original_latitudes: Sequence of sighting latitudes (degrees).
        original_longitudes: Sequence of sighting longitudes (degrees), aligned
            with ``original_latitudes``.
        kde_object: Callable that takes a ``(2, n)`` array of
            ``[longitudes, latitudes]`` and returns ``n`` densities.

    Returns:
        A Markdown string naming the closest landmark.
    """
    # Landmark name -> (latitude, longitude).
    CAMPUS_LANDMARKS = {
        "Scaife Hall": (40.441742986804336, -79.94725195600002),
        "Hunt Library": (40.44097574857165, -79.94362666281333),
        "Cohon University Center": (40.44401378993309, -79.94172335009584),
        "Gates Hillman Complex": (40.4436463605335, -79.94442701667683),
        "Wean Hall": (40.44267896399903, -79.94582169457243),
        "Gesling Stadium": (40.443038206822905, -79.94038027450188),
        "The Fence": (40.44221744932438, -79.9435687098247)
    }

    # The hotspot is the observed point where the KDE is highest.
    all_coords = np.vstack([original_longitudes, original_latitudes])
    densities = kde_object(all_coords)
    hotspot_index = np.argmax(densities)
    hotspot_lat = original_latitudes[hotspot_index]
    hotspot_lon = original_longitudes[hotspot_index]

    # A degree of longitude covers less ground than a degree of latitude away
    # from the equator. Scaling the longitude difference by cos(latitude)
    # keeps the flat-plane approximation but makes both axes comparable, so
    # the nearest landmark is not biased towards east/west neighbours.
    lon_scale = math.cos(math.radians(hotspot_lat))

    def distance(landmark):
        landmark_lat, landmark_lon = CAMPUS_LANDMARKS[landmark]
        return math.hypot(
            hotspot_lat - landmark_lat,
            (hotspot_lon - landmark_lon) * lon_scale,
        )

    closest_landmark = min(CAMPUS_LANDMARKS, key=distance)
    return f"π **Hotspot Analysis**: The highest concentration was found closest to **{closest_landmark}** on campus."
def plot_kde_and_points_for_gradio(min_lat, max_lat, min_lon, max_lon, original_latitudes, original_longitudes, kde_object):
    """Render the KDE as a static heatmap image and an interactive point map.

    Args:
        min_lat, max_lat, min_lon, max_lon: Bounding box used for the heatmap
            grid extent and for the initial bounds of the interactive map.
        original_latitudes: Sighting latitudes.
        original_longitudes: Sighting longitudes, aligned with the latitudes.
        kde_object: Callable that takes a ``(2, n)`` array of
            ``[longitudes, latitudes]`` and returns ``n`` densities.

    Returns:
        A ``(heatmap_path, map_html)`` tuple: the path of the saved PNG
        heatmap and the HTML representation of the folium map.
    """
    heatmap_path = "lanternfly_kde_heatmap.png"

    # --- Static heatmap: evaluate the KDE on a 100 x 100 grid over the box ---
    x, y = np.mgrid[min_lon:max_lon:100j, min_lat:max_lat:100j]
    positions = np.vstack([x.ravel(), y.ravel()])
    z = kde_object(positions).reshape(x.shape)
    # Guard against a flat surface, which would otherwise divide by zero.
    z_normalized = (z - z.min()) / (z.max() - z.min()) if z.max() > z.min() else np.zeros_like(z)

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        im = ax.imshow(z_normalized.T, origin='lower', extent=[min_lon, max_lon, min_lat, max_lat], cmap='hot', aspect='auto')
        fig.colorbar(im, ax=ax, label='Normalized Density (0-1)')
        ax.set_title('Lanternfly Sightings KDE Heatmap (Static)')
        # Save through the figure handle rather than the implicit current figure.
        fig.savefig(heatmap_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    # --- Interactive map: one marker per sighting, coloured by density ---
    m_colored_points = folium.Map()
    bounds = [[min_lat, min_lon], [max_lat, max_lon]]
    m_colored_points.fit_bounds(bounds)

    original_coordinates = np.vstack([original_longitudes, original_latitudes])
    density_at_points = kde_object(original_coordinates)
    density_normalized_for_color = (density_at_points - density_at_points.min()) / (density_at_points.max() - density_at_points.min() + 1e-9)
    max_density = density_at_points.max()

    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which newer Matplotlib
    # releases no longer provide.
    colormap = plt.get_cmap('viridis')

    # The densities above already cover every point, so reuse them for the
    # tooltips instead of re-evaluating the KDE once per marker.
    for lat, lon, density_norm_color, raw_density in zip(
        original_latitudes, original_longitudes, density_normalized_for_color, density_at_points
    ):
        if not (min_lat <= lat <= max_lat and min_lon <= lon <= max_lon):
            continue  # only points inside the bounding box get a marker
        color = matplotlib.colors.rgb2hex(colormap(density_norm_color))
        normalized_tooltip_density = raw_density / max_density if max_density > 0 else 0
        folium.CircleMarker(
            location=[lat, lon], radius=5, color=color, fill=True,
            fill_color=color, fill_opacity=0.7,
            tooltip=f"Normalized Density: {normalized_tooltip_density:.4f}"
        ).add_to(m_colored_points)

    return heatmap_path, m_colored_points._repr_html_()
def run_full_analysis_and_update_ui():
    """Run the load -> KDE -> plot pipeline and return updates for the KDE tab.

    Returns:
        A 3-tuple of component updates ``(image, html, markdown)``. On success
        all three are visible and carry the heatmap path, the interactive map
        HTML and the hotspot message. On failure only the HTML component is
        visible and contains an error heading.
    """
    import html  # local import: only needed to escape the error text below

    print("Loading data...")
    lanternfly_df = load_dataframe_from_huggingface()
    if lanternfly_df is None:
        return gr.Image(visible=False), gr.HTML("<h3>Error...</h3>", visible=True), gr.Markdown(visible=False)

    print("Calculating KDE...")
    latitudes, longitudes, kde_object, error = calculate_kde_from_dataframe(lanternfly_df)
    if error:
        # Show the actual reason instead of a generic placeholder; the text is
        # escaped because it may contain arbitrary exception output.
        error_html = f"<h3>{html.escape(str(error))}</h3>"
        return gr.Image(visible=False), gr.HTML(error_html, visible=True), gr.Markdown(visible=False)

    print("Generating visualizations...")
    heatmap_path, interactive_map_html = plot_kde_and_points_for_gradio(
        pittsburgh_lat_min, pittsburgh_lat_max,
        pittsburgh_lon_min, pittsburgh_lon_max,
        latitudes, longitudes, kde_object
    )

    print("Finding hotspot landmark...")
    hotspot_message = find_hotspot_landmark(latitudes, longitudes, kde_object)

    return (
        gr.Image(value=heatmap_path, visible=True),
        gr.HTML(value=interactive_map_html, visible=True),
        gr.Markdown(value=hotspot_message, visible=True)
    )
| # ============================================================================== | |
| # CELL 4: GRADIO UI DEFINITIONS | |
| # ============================================================================== | |
def field_capture_ui(camera):
    """Build the location/time capture panel and wire up its buttons.

    Args:
        camera: The shared image component; passed as the first input to the
            save handler.
    """
    # NOTE(review): the leading symbols in several labels below look like
    # mis-decoded emoji - confirm the file encoding before changing them.
    with gr.Blocks():
        # A space after '#' matches every other heading in this file.
        gr.Markdown("# Lanternfly Data Logging")
        with gr.Column(scale=1):
            gr.Markdown("### π Location Data")
            gps_btn = gr.Button("π Get GPS", variant="primary")
            # The elem_id values are what the GPS JavaScript looks up
            # ('#lat textarea', '#lon textarea', ...), so they must not change.
            with gr.Row():
                lat_box = gr.Textbox(label="Latitude", interactive=True, value="0.0", elem_id="lat")
                lon_box = gr.Textbox(label="Longitude", interactive=True, value="0.0", elem_id="lon")
            with gr.Row():
                accuracy_box = gr.Textbox(label="Accuracy (meters)", interactive=True, value="0.0", elem_id="accuracy")
                device_ts_box = gr.Textbox(label="Device Timestamp", interactive=True, elem_id="device_ts")
            time_btn = gr.Button("π Get Current Time")
            save_btn = gr.Button("πΎ Save (Test Mode)")
            status = gr.Markdown("π **Ready**")
            preview = gr.JSON(label="Preview JSON")
        # The GPS button runs purely in the browser: no Python function, only JS.
        gps_btn.click(fn=None, inputs=[], outputs=[], js=get_gps_js())
        time_btn.click(fn=placeholder_time_capture, inputs=[], outputs=[status, device_ts_box])
        save_btn.click(fn=placeholder_save_action, inputs=[camera, lat_box, lon_box, accuracy_box, device_ts_box], outputs=[status, preview])
def image_model_ui(image_in):
    """Build the classification-results panel and attach it to ``image_in``.

    Args:
        image_in: The image component whose ``change`` event triggers
            ``do_predict``; it is also the component the example images are
            loaded into.
    """
    with gr.Blocks():
        gr.Markdown("# Image Classification Results")
        gr.Markdown("Uses an EfficientNetB1 model to classify the uploaded image.")
        if PREDICTOR is None:
            # Report the load failure recorded when the module was imported.
            gr.Warning(PREDICTOR_LOAD_STATUS)
        with gr.Row():
            proba_pretty = gr.Label(num_top_classes=2, label="Class Probabilities")
            confidence_output = gr.Textbox(label="Prediction Summary")
        # Attach prediction logic to the passed-in image component; the two
        # outputs receive the probability dict and the summary string.
        image_in.change(
            fn=do_predict,
            inputs=[image_in],
            outputs=[proba_pretty, confidence_output]
        )
        # Example images section.
        # This assumes an 'examples' folder containing these images exists.
        gr.Examples(
            examples=[
                "examples/lanternfly_example.jpg",
                "examples/other_insect_example.jpg",
                "examples/no_insect_example.jpg"
            ],
            inputs=[image_in],
            label="Click an Example to Classify",
            examples_per_page=3
        )
def kde_analysis_ui():
    """Build the KDE tab: introductory text and a trigger button on top,
    with the (initially hidden) result components underneath.
    """
    # --- Introductory text, rendered in order above the button ---
    intro_paragraphs = (
        "# Spotted Lanternfly Kernel Density Estimation Analysis",
        "Click the button to generate a Kernel Density Estimation (KDE) analysis based on the data gathered from the classification tab.",
        "This data can be found at rlogh/lanternfly-data on Hugging Face and contains images, geolocal, and temporal data for all samples.",
        "This dataset is public and available for use for any research or learning purposes.",
    )
    for paragraph in intro_paragraphs:
        gr.Markdown(paragraph)

    run_button = gr.Button("Generate KDE Visualizations")

    # --- Result components, hidden until the analysis has run ---
    hotspot_message = gr.Markdown(visible=False)
    with gr.Row():
        static_heatmap = gr.Image(label="KDE Heatmap (Static)", visible=False)
        interactive_map = gr.HTML(label="Interactive Density Map", visible=False)

    # --- Wiring: output order matches the tuple returned by the handler ---
    run_button.click(
        fn=run_full_analysis_and_update_ui,
        inputs=None,
        outputs=[static_heatmap, interactive_map, hotspot_message],
    )
# Top-level application: an introduction followed by two tabs. The UI builder
# functions above are called inside this Blocks context so that the
# components they create are added to the enclosing tab.
with gr.Blocks(title="Unified Lanternfly App") as app:
    gr.Markdown("# Lanternfly Tracker")
    gr.Markdown("This application allows for the tracking of concentrated lanternflies, mainly around Carnegie Mellon University.")
    gr.Markdown("It combines two tools: (1) A field capture and AI Image classifer for identifying lanternflies, and (2) a Kernel Density Estimation (KDE) ML model to visualize lanternfly hotspots on campus.")
    gr.Markdown("Photos can be taken and classified as Lanternflies in the Capture & Classification tab. In future this data can be saved in real time to the dataset")
    gr.Markdown("To view the overal distribution of lanternflies based on collected data, use the Spatial Analysis (KDE) tab.")
    # TAB 1: capture and classification share a single image input.
    with gr.Tab("Capture & Classification"):
        gr.Info("GPS functionality is now enabled! Data saving is in test mode.")
        shared_image_input = gr.Image(
            streaming=False, height=380, label="π· Upload Photo (or use camera)",
            type="pil", sources=["webcam", "upload"]
        )
        with gr.Row():
            with gr.Column(scale=1):
                image_model_ui(shared_image_input)
            with gr.Column(scale=1):
                field_capture_ui(shared_image_input)
    # TAB 2: KDE analysis.
    with gr.Tab("Spatial Analysis (KDE)"):
        # This single function call builds the entire tab.
        kde_analysis_ui()
# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    app.launch()