"""Interactive K-Nearest Neighbors (KNN) visualizer built with Gradio and matplotlib."""
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from matplotlib.patches import Circle, ConnectionPatch, FancyArrowPatch, FancyBboxPatch, Polygon
from PIL import Image
from scipy.spatial.distance import chebyshev, cityblock, euclidean


class KNNVisualizer:
    """KNN demo over a deterministic three-class 2-D dataset.

    ``reset_data`` (re)generates the training set into ``self.X`` / ``self.y``;
    ``visualize`` renders one KNN query as a PIL image with three panels:
    the scatter plot, a sorted distance table and a textual walkthrough.
    """

    # Metric name -> SciPy distance function. Unknown names fall back to
    # Euclidean in calculate_distance.
    _METRICS = {
        'euclidean': euclidean,
        'manhattan': cityblock,
        'chebyshev': chebyshev,
    }

    _CLASS_COLORS = ('blue', 'red', 'green')
    _CLASS_NAMES = ('Class 0 (Blue)', 'Class 1 (Red)', 'Class 2 (Green)')

    # Explanatory lines printed under "DISTANCE METRIC" in the stats panel.
    _METRIC_NOTES = {
        'euclidean': (
            "Formula: d = √[(x₂-x₁)² + (y₂-y₁)²]\n",
            " • Standard straight-line distance\n",
            " • Most commonly used metric\n",
        ),
        'manhattan': (
            "Formula: d = |x₂-x₁| + |y₂-y₁|\n",
            " • Also called 'City Block' distance\n",
            " • Sum of absolute differences\n",
        ),
        'chebyshev': (
            "Formula: d = max(|x₂-x₁|, |y₂-y₁|)\n",
            " • Maximum absolute difference\n",
            " • Chess king's move distance\n",
        ),
    }

    def __init__(self):
        self.reset_data()
        # NOTE: this overrides the default query point set by reset_data();
        # the attribute is not read anywhere else in this class.
        self.test_point = None

    def reset_data(self, n_samples=30, seed=42):
        """Generate sample data with three Gaussian classes.

        Args:
            n_samples: Points per class.
            seed: Seed for a private ``RandomState``. The global NumPy RNG is
                left untouched; with the defaults the data is identical to
                ``np.random.seed(42)`` followed by ``np.random.randn`` calls.
        """
        rng = np.random.RandomState(seed)
        # Class 0 (blue) - bottom left
        self.X0 = rng.randn(n_samples, 2) * 1.2 + np.array([3, 3])
        # Class 1 (red) - top right
        self.X1 = rng.randn(n_samples, 2) * 1.2 + np.array([7, 7])
        # Class 2 (green) - top left
        self.X2 = rng.randn(n_samples, 2) * 1.2 + np.array([3, 7])

        self.X = np.vstack([self.X0, self.X1, self.X2])
        self.y = np.hstack([np.zeros(n_samples), np.ones(n_samples), np.full(n_samples, 2)])
        self.test_point = np.array([5.0, 5.0])

    def calculate_distance(self, point1, point2, metric='euclidean'):
        """Return the distance between two points under *metric*.

        Supported metrics: 'euclidean', 'manhattan' and 'chebyshev'; any other
        value falls back to Euclidean.
        """
        return self._METRICS.get(metric, euclidean)(point1, point2)

    def _sorted_distances(self, test_point, metric='euclidean'):
        """Return ``(index, distance, label)`` for every training point, nearest first.

        ``list.sort`` is stable, so equal distances keep training-set order.
        """
        distances = [
            (i, self.calculate_distance(test_point, point, metric), self.y[i])
            for i, point in enumerate(self.X)
        ]
        distances.sort(key=lambda item: item[1])
        return distances

    def find_k_nearest_neighbors(self, test_point, k, metric='euclidean'):
        """Find the k nearest neighbors to the test point.

        Returns:
            A list of ``(index, distance, label)`` tuples sorted by distance.
        """
        return self._sorted_distances(test_point, metric)[:k]

    def predict_class(self, neighbors):
        """Predict a class by majority vote among *neighbors*.

        Returns:
            ``(predicted_class, class_counts)`` where ``class_counts[c]`` counts
            neighbors labelled ``c``; the array only extends up to the largest
            label present. Ties go to the lowest class index because
            ``np.argmax`` returns the first maximum.
        """
        classes = [int(n[2]) for n in neighbors]
        class_counts = np.bincount(classes)
        return np.argmax(class_counts), class_counts

    @staticmethod
    def _neighborhood_patch(center, radius, metric, **style):
        """Return a patch outlining every point within *radius* of *center*.

        The "ball" depends on the metric: a circle for Euclidean, a diamond
        for Manhattan (L1) and an axis-aligned square for Chebyshev (L-inf).
        """
        cx, cy = center
        if metric == 'manhattan':
            vertices = [(cx + radius, cy), (cx, cy + radius),
                        (cx - radius, cy), (cx, cy - radius)]
            return Polygon(vertices, closed=True, **style)
        if metric == 'chebyshev':
            vertices = [(cx - radius, cy - radius), (cx + radius, cy - radius),
                        (cx + radius, cy + radius), (cx - radius, cy + radius)]
            return Polygon(vertices, closed=True, **style)
        # Euclidean, and the Euclidean fallback used for unknown metrics.
        return Circle(center, radius, **style)

    def visualize(self, test_x, test_y, k_value, distance_metric, show_all_distances):
        """Render a complete KNN walkthrough for one query point.

        Args:
            test_x, test_y: Query coordinates (anything ``float()`` accepts).
            k_value: Number of neighbors (anything ``int()`` accepts); clamped
                to ``[1, len(self.X)]``.
            distance_metric: 'euclidean', 'manhattan' or 'chebyshev'.
            show_all_distances: If true, draw a labelled line to every
                neighbor instead of only the 10 closest.

        Returns:
            A ``PIL.Image.Image`` containing the three panels.

        If the coordinates or k cannot be parsed, the query falls back to the
        point (5, 5) with k = 5.
        """
        try:
            test_point = np.array([float(test_x), float(test_y)])
            k = int(k_value)
        except (TypeError, ValueError, OverflowError):
            # Unparseable UI input: fall back to a sensible default query.
            test_point = np.array([5.0, 5.0])
            k = 5
        k = max(1, min(k, len(self.X)))

        # One distance pass feeds both the neighbor list and the table.
        all_distances = self._sorted_distances(test_point, distance_metric)
        neighbors = all_distances[:k]
        predicted_class, class_counts = self.predict_class(neighbors)

        # Build the figure without pyplot: pyplot keeps global "current
        # figure" state that is not thread-safe, and Gradio callbacks may run
        # concurrently. A bare Figure is freed like any other object.
        fig = Figure(figsize=(20, 12))
        gs = fig.add_gridspec(2, 2, height_ratios=[1.2, 1], width_ratios=[1.5, 1])
        ax_main = fig.add_subplot(gs[0, 0])   # main KNN visualization
        ax_table = fig.add_subplot(gs[1, 0])  # distance calculations table
        ax_stats = fig.add_subplot(gs[:, 1])  # statistics and breakdown

        self._plot_neighbors(ax_main, test_point, neighbors, k, predicted_class,
                             distance_metric, show_all_distances)
        self._plot_distance_table(ax_table, all_distances, k)

        ax_stats.axis('off')
        stats_text = self._build_stats_text(test_point, k, distance_metric, neighbors,
                                            predicted_class, class_counts)
        ax_stats.text(0.05, 0.95, stats_text, transform=ax_stats.transAxes,
                      fontsize=9, verticalalignment='top',
                      bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8),
                      family='monospace')

        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=120, bbox_inches='tight')
        buf.seek(0)
        img = Image.open(buf)
        img.load()  # decode now rather than lazily on first pixel access
        return img

    def _plot_neighbors(self, ax, test_point, neighbors, k, predicted_class,
                        distance_metric, show_all_distances):
        """Draw the training set, highlighted neighbors, query point and neighborhood."""
        ax.set_facecolor('#f0f0f0')

        for class_idx, (color, name) in enumerate(zip(self._CLASS_COLORS, self._CLASS_NAMES)):
            mask = self.y == class_idx
            ax.scatter(self.X[mask, 0], self.X[mask, 1], c=color, label=name,
                       s=100, alpha=0.6, edgecolors='k', linewidths=1.5)

        for rank, (n_idx, dist, n_class) in enumerate(neighbors):
            point = self.X[n_idx]
            ax.add_patch(Circle(point, 0.3, color=self._CLASS_COLORS[int(n_class)],
                                fill=False, linewidth=3, linestyle='--', alpha=0.8))
            # Only the 10 closest links are drawn by default to limit clutter.
            if show_all_distances or rank < 10:
                ax.plot([test_point[0], point[0]], [test_point[1], point[1]],
                        'k--', alpha=0.3, linewidth=1)
                mid_x, mid_y = (test_point + point) / 2
                ax.text(mid_x, mid_y, f'{dist:.2f}', fontsize=8,
                        bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7))

        ax.scatter(test_point[0], test_point[1], c=self._CLASS_COLORS[predicted_class],
                   marker='*', s=800, edgecolors='black', linewidths=3,
                   label=f'Test Point (Predicted: Class {predicted_class})', zorder=100)

        # neighbors is sorted, so the last distance is the radius of the
        # smallest metric ball around the query containing all k neighbors.
        radius = neighbors[-1][1]
        ax.add_patch(self._neighborhood_patch(
            test_point, radius, distance_metric,
            color='purple', fill=False, linewidth=2.5, linestyle=':', alpha=0.6,
            label=f'Decision Boundary (r={radius:.2f})'))

        ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
        ax.set_xlabel('Feature 1 (X)', fontsize=14, fontweight='bold')
        ax.set_ylabel('Feature 2 (Y)', fontsize=14, fontweight='bold')
        ax.set_title(f'K-Nearest Neighbors (k={k}, metric={distance_metric})',
                     fontsize=16, fontweight='bold')
        ax.legend(fontsize=10, loc='upper left', framealpha=0.9)
        ax.set_xlim(-1, 11)
        ax.set_ylim(-1, 11)
        # Equal scaling so circles/diamonds/squares are not distorted on screen.
        ax.set_aspect('equal', adjustable='box')

    def _plot_distance_table(self, ax, all_distances, k, max_rows=15):
        """Render the *max_rows* closest points as a table, marking the k neighbors."""
        ax.axis('off')

        header = ['Rank', 'Index', 'X', 'Y', 'Class', 'Distance', 'Neighbor?']
        table_data = [header]
        for rank, (idx, dist, point_class) in enumerate(all_distances[:max_rows], 1):
            x, y = self.X[idx]
            table_data.append([
                f'{rank}',
                f'{idx}',
                f'{x:.2f}',
                f'{y:.2f}',
                f'{int(point_class)}',
                f'{dist:.3f}',
                '✓' if rank <= k else '',
            ])

        table = ax.table(cellText=table_data, cellLoc='center', loc='center',
                         bbox=[0, 0, 1, 1])
        table.auto_set_font_size(False)
        table.set_fontsize(9)
        table.scale(1, 2)

        n_cols = len(header)
        for col in range(n_cols):
            cell = table[(0, col)]
            cell.set_facecolor('#4CAF50')
            cell.set_text_props(weight='bold', color='white')

        for row in range(1, len(table_data)):
            if row <= k:  # row number == distance rank, so rows 1..k are neighbors
                for col in range(n_cols):
                    table[(row, col)].set_facecolor('#E8F5E9')
            class_col = int(table_data[row][4])
            table[(row, 4)].set_facecolor(self._CLASS_COLORS[class_col])
            table[(row, 4)].set_alpha(0.3)

        ax.set_title('Distance Calculations (Sorted by Distance)',
                     fontsize=14, fontweight='bold', pad=20)

    @staticmethod
    def _calculation_lines(test_point, point, dist, metric):
        """Return the worked distance calculation for one neighbor, line by line."""
        dx = point[0] - test_point[0]
        dy = point[1] - test_point[1]
        if metric == 'euclidean':
            return [
                f" Calculation: √[({dx:.2f})² + ({dy:.2f})²]\n",
                f" = √[{dx**2:.2f} + {dy**2:.2f}]\n",
                f" = {dist:.4f}\n",
            ]
        if metric == 'manhattan':
            return [
                f" Calculation: |{abs(dx):.2f}| + |{abs(dy):.2f}|\n",
                f" = {dist:.4f}\n",
            ]
        if metric == 'chebyshev':
            return [
                f" Calculation: max(|{abs(dx):.2f}|, |{abs(dy):.2f}|)\n",
                f" = {dist:.4f}\n",
            ]
        return []

    def _build_stats_text(self, test_point, k, distance_metric, neighbors,
                          predicted_class, class_counts):
        """Compose the monospace algorithm walkthrough shown in the side panel."""
        rule = "=" * 60
        thin_rule = "-" * 60
        metric_label = distance_metric.upper()

        parts = [
            "K-NEAREST NEIGHBORS ALGORITHM\n",
            rule + "\n\n",
            "TEST POINT COORDINATES:\n",
            f" • X: {test_point[0]:.2f}\n",
            f" • Y: {test_point[1]:.2f}\n\n",
            "ALGORITHM PARAMETERS:\n",
            f" • K value: {k}\n",
            f" • Distance metric: {metric_label}\n",
            f" • Total training samples: {len(self.X)}\n\n",
            f"DISTANCE METRIC: {metric_label}\n",
            thin_rule + "\n",
        ]
        parts.extend(self._METRIC_NOTES.get(distance_metric, ()))
        parts.append("\n")

        parts.append("K NEAREST NEIGHBORS FOUND:\n")
        parts.append(thin_rule + "\n")
        for rank, (idx, dist, point_class) in enumerate(neighbors, 1):
            point = self.X[idx]
            parts.append(f"\n{rank}. Point #{idx} (Class {int(point_class)})\n")
            parts.append(f" Position: ({point[0]:.2f}, {point[1]:.2f})\n")
            parts.append(f" Distance: {dist:.4f}\n")
            # Show the worked calculation for the three closest neighbors only.
            if rank <= 3:
                parts.extend(self._calculation_lines(test_point, point, dist, distance_metric))

        parts.append("\n\nCLASS DISTRIBUTION IN K NEIGHBORS:\n")
        parts.append(thin_rule + "\n")
        for class_idx in range(len(self._CLASS_COLORS)):
            # bincount only extends to the largest label present.
            count = class_counts[class_idx] if class_idx < len(class_counts) else 0
            percentage = count / k * 100
            bar = '█' * int(percentage / 5)
            parts.append(f"Class {class_idx}: {count}/{k} ({percentage:.1f}%) {bar}\n")

        winner_votes = class_counts[predicted_class]
        parts += [
            "\n\nPREDICTION RESULT:\n",
            rule + "\n",
            f" → Predicted Class: {predicted_class}\n",
            f" → Confidence: {winner_votes}/{k} neighbors\n",
            f" → Percentage: {winner_votes / k * 100:.1f}%\n\n",
            "ALGORITHM STEPS:\n",
            thin_rule + "\n",
            "1. Calculate distance from test point to all\n",
            " training points using selected metric\n",
            "2. Sort all points by distance (ascending)\n",
            "3. Select the K nearest points\n",
            "4. Count class labels among K neighbors\n",
            "5. Predict class with majority vote\n",
        ]
        return "".join(parts)


_INTRO_MD = """
# 🎯 Interactive K-Nearest Neighbors (KNN) Algorithm Visualizer

Explore how the KNN algorithm works by visualizing distance calculations and neighbor identification!

**Instructions:**
1. Set the test point coordinates (X, Y)
2. Choose the number of neighbors (K)
3. Select a distance metric
4. Click "Update Visualization" to see the results
"""

_HELP_MD = """
### Distance Metrics:
- **Euclidean**: Standard straight-line distance
- **Manhattan**: Sum of absolute differences (city block)
- **Chebyshev**: Maximum absolute difference

### Try These Examples:
- **Test Point (5, 5), K=5**: See balanced classification
- **Test Point (2, 2), K=3**: Point near Class 0
- **Test Point (8, 8), K=7**: Point near Class 1
- **Different K values**: See how it affects prediction
"""

# Create the visualizer
knn_viz = KNNVisualizer()

# Create Gradio interface
with gr.Blocks(title="K-Nearest Neighbors (KNN) Visualizer", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_INTRO_MD)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Test Point Configuration")
            test_x = gr.Slider(minimum=-1, maximum=11, value=5.0, step=0.1,
                               label="Test Point X Coordinate")
            test_y = gr.Slider(minimum=-1, maximum=11, value=5.0, step=0.1,
                               label="Test Point Y Coordinate")

            gr.Markdown("### KNN Parameters")
            k_value = gr.Slider(minimum=1, maximum=20, value=5, step=1,
                                label="K (Number of Neighbors)")
            distance_metric = gr.Radio(
                choices=['euclidean', 'manhattan', 'chebyshev'],
                value='euclidean',
                label="Distance Metric",
            )
            show_all_distances = gr.Checkbox(
                value=False,
                label="Show all distance lines (may be cluttered)",
            )
            update_btn = gr.Button("🔄 Update Visualization", variant="primary", size="lg")

            gr.Markdown(_HELP_MD)

        with gr.Column(scale=2):
            output_image = gr.Image(label="KNN Visualization", height=900)

    # Every trigger renders with the same ordered inputs.
    viz_inputs = [test_x, test_y, k_value, distance_metric, show_all_distances]

    # Update visualization
    update_btn.click(fn=knn_viz.visualize, inputs=viz_inputs, outputs=output_image)

    # Also update on slider/radio/checkbox change
    for component in viz_inputs:
        component.change(fn=knn_viz.visualize, inputs=viz_inputs, outputs=output_image)

    # Initial visualization
    demo.load(fn=knn_viz.visualize, inputs=viz_inputs, outputs=output_image)

# Launch the app
if __name__ == "__main__":
    demo.launch()