{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "164d7e04",
   "metadata": {},
   "source": [
    "# Multi-Split Decision Tree Visualizer\n",
    "\n",
    "This notebook builds an interactive Gradio app that shows how a decision tree partitions a two-feature space with **multiple splits**, together with the resulting **decision tree structure**.\n",
    "\n",
    "## ✨ Features\n",
    "- **Multiple partitions**: add as many splits as you like to grow the tree\n",
    "- **Decision tree visualization**: see every node of the tree and how the nodes connect\n",
    "- **Interactive split entry**: one split per line as `feature, threshold`, where `feature` is `x` (width) or `y` (height)\n",
    "- **Depth-wise splits**: split *k* is applied to every node at depth *k − 1*; a node stays a leaf when a split would leave one of its children empty\n",
    "- **Comprehensive statistics**: entropy and Gini index for every node and leaf\n",
    "- **Color-coded visualization**:\n",
    "  - Blue arrows = \"Yes\" branch (≤ threshold)\n",
    "  - Red arrows = \"No\" branch (> threshold)\n",
    "  - Light green boxes = internal (split) nodes\n",
    "  - Light blue leaves = predicts Class 0 (Lemon)\n",
    "  - Orange leaves = predicts Class 1 (Orange)\n",
    "\n",
    "## 📊 Three-Panel Display\n",
    "1. **Top-left**: partitioned feature space with all split boundaries\n",
    "2. **Bottom-left**: complete decision tree structure\n",
    "3. **Right**: detailed statistics and impurity measures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b654a81",
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import threading\n",
    "\n",
    "import gradio as gr\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib.patches import FancyArrowPatch, FancyBboxPatch, Rectangle\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f9a1c2e",
   "metadata": {},
   "source": [
    "## Decision-tree model\n",
    "\n",
    "`DecisionTreePartitioner` holds a fixed synthetic dataset (50 \"lemons\" around (3, 3) and 50 \"oranges\" around (7, 7)) and grows a tree from an ordered list of `(feature, threshold)` splits. `visualize` parses the app's text box and renders all three panels as a single image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b7d2e4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TreeNode:\n",
    "    \"\"\"A node in the decision tree grown from a user-supplied split sequence.\"\"\"\n",
    "\n",
    "    def __init__(self, depth=0, bounds=None):\n",
    "        self.depth = depth\n",
    "        # Axis-aligned region of this node: {'x': (lo, hi), 'y': (lo, hi)}.\n",
    "        # Copied so that nodes never share (and mutate) the same dict.\n",
    "        self.bounds = dict(bounds) if bounds else {'x': (0, 10), 'y': (0, 10)}\n",
    "        self.feature = None        # 'x' or 'y' once the node is split\n",
    "        self.threshold = None      # samples with feature <= threshold go left\n",
    "        self.left = None           # Yes-branch child (feature <= threshold)\n",
    "        self.right = None          # No-branch child (feature > threshold)\n",
    "        self.is_leaf = True\n",
    "        self.samples = None        # number of samples reaching this node\n",
    "        self.class_counts = None   # {class_label: count}\n",
    "        self.entropy = None\n",
    "        self.gini = None\n",
    "        self.majority_class = None\n",
    "\n",
    "\n",
    "class DecisionTreePartitioner:\n",
    "    \"\"\"Two-class toy dataset plus a decision tree grown from an ordered split list.\n",
    "\n",
    "    Split ``k`` (0-based) in ``self.splits`` is applied to *every* node at depth\n",
    "    ``k``. A node whose split would leave one child empty stays a leaf, and no\n",
    "    later split is applied below it.\n",
    "    \"\"\"\n",
    "\n",
    "    # Accepted (lower-case) spellings of each feature, mapped to the internal key.\n",
    "    FEATURE_ALIASES = {'x': 'x', 'width': 'x', 'y': 'y', 'height': 'y'}\n",
    "\n",
    "    def __init__(self, seed=42, n_samples=50):\n",
    "        \"\"\"Generate the dataset.\n",
    "\n",
    "        Args:\n",
    "            seed: seed of the private random generator used by ``reset_data``.\n",
    "            n_samples: number of samples generated per class.\n",
    "        \"\"\"\n",
    "        self.seed = seed\n",
    "        self.n_samples = n_samples\n",
    "        # Gradio may run event handlers concurrently in worker threads; this lock\n",
    "        # serialises visualize(), which mutates self.splits/self.root and uses pyplot.\n",
    "        self._lock = threading.Lock()\n",
    "        self.reset_data()\n",
    "\n",
    "    def reset_data(self):\n",
    "        \"\"\"Regenerate the two-class sample data and clear any splits and tree.\"\"\"\n",
    "        # A private RandomState yields exactly the samples that np.random.seed(seed)\n",
    "        # followed by np.random.randn would, without reseeding NumPy's global generator.\n",
    "        rng = np.random.RandomState(self.seed)\n",
    "        n = self.n_samples\n",
    "        self.X0 = rng.randn(n, 2) * 1.5 + np.array([3, 3])  # Class 0 (lemon), around (3, 3)\n",
    "        self.X1 = rng.randn(n, 2) * 1.5 + np.array([7, 7])  # Class 1 (orange), around (7, 7)\n",
    "        self.X = np.vstack([self.X0, self.X1])\n",
    "        self.y = np.hstack([np.zeros(n), np.ones(n)])\n",
    "        self.splits = []  # (feature, threshold) tuples, feature in {'x', 'y'}\n",
    "        self.root = None\n",
    "\n",
    "    def calculate_entropy(self, y):\n",
    "        \"\"\"Return the Shannon entropy (in bits) of label array ``y``; 0.0 when empty.\"\"\"\n",
    "        if len(y) == 0:\n",
    "            return 0.0\n",
    "        _, counts = np.unique(y, return_counts=True)\n",
    "        probabilities = counts / len(y)\n",
    "        # np.unique only reports labels that occur, so every probability is > 0 and\n",
    "        # log2 needs no epsilon (one would make pure nodes come out slightly negative).\n",
    "        entropy = -np.sum(probabilities * np.log2(probabilities))\n",
    "        # `+ 0.0` turns the -0.0 of a pure node into 0.0 so it never prints as -0.000.\n",
    "        return float(entropy) + 0.0\n",
    "\n",
    "    def calculate_gini(self, y):\n",
    "        \"\"\"Return the Gini impurity of label array ``y``; 0.0 when empty.\"\"\"\n",
    "        if len(y) == 0:\n",
    "            return 0.0\n",
    "        _, counts = np.unique(y, return_counts=True)\n",
    "        probabilities = counts / len(y)\n",
    "        return float(1.0 - np.sum(probabilities ** 2))\n",
    "\n",
    "    def _normalize_feature(self, feature):\n",
    "        \"\"\"Map a user spelling of a feature to 'x' or 'y'; return None if unknown.\"\"\"\n",
    "        return self.FEATURE_ALIASES.get(str(feature).strip().lower())\n",
    "\n",
    "    def _parse_splits(self, split_history):\n",
    "        \"\"\"Parse text holding one ``feature, threshold`` split per line.\n",
    "\n",
    "        Returns:\n",
    "            (splits, rejected): the valid ``(feature, threshold)`` tuples in order,\n",
    "            and the non-blank lines that could not be parsed.\n",
    "        \"\"\"\n",
    "        splits, rejected = [], []\n",
    "        for raw_line in (split_history or '').splitlines():\n",
    "            line = raw_line.strip()\n",
    "            if not line:\n",
    "                continue\n",
    "            parts = line.split(',')\n",
    "            axis = self._normalize_feature(parts[0]) if len(parts) == 2 else None\n",
    "            try:\n",
    "                threshold = float(parts[1]) if axis else float('nan')\n",
    "            except ValueError:\n",
    "                threshold = float('nan')\n",
    "            if axis and np.isfinite(threshold):\n",
    "                splits.append((axis, threshold))\n",
    "            else:\n",
    "                rejected.append(line)\n",
    "        return splits, rejected\n",
    "\n",
    "    def build_tree_from_splits(self):\n",
    "        \"\"\"Rebuild ``self.root`` from ``self.splits`` and return it.\n",
    "\n",
    "        With no splits the tree is cleared (``self.root = None``), so a tree left\n",
    "        over from an earlier call is never shown for an empty split list.\n",
    "        \"\"\"\n",
    "        if not self.splits:\n",
    "            self.root = None\n",
    "            return None\n",
    "        self.root = TreeNode(depth=0)\n",
    "        self._build_node(self.root, np.arange(len(self.y)), 0)\n",
    "        return self.root\n",
    "\n",
    "    def _build_node(self, node, indices, split_idx):\n",
    "        \"\"\"Fill in ``node``'s statistics and recursively apply ``self.splits[split_idx:]``.\n",
    "\n",
    "        Args:\n",
    "            node: the TreeNode to populate.\n",
    "            indices: integer indices into ``self.X``/``self.y`` of the samples at ``node``.\n",
    "            split_idx: index of the split to apply at ``node``'s depth.\n",
    "        \"\"\"\n",
    "        if len(indices) == 0:\n",
    "            return\n",
    "\n",
    "        y_node = self.y[indices]\n",
    "        labels, counts = np.unique(y_node, return_counts=True)\n",
    "        node.samples = len(indices)\n",
    "        node.class_counts = {int(label): int(count) for label, count in zip(labels, counts)}\n",
    "        node.entropy = self.calculate_entropy(y_node)\n",
    "        node.gini = self.calculate_gini(y_node)\n",
    "        node.majority_class = int(labels[np.argmax(counts)])\n",
    "\n",
    "        if split_idx >= len(self.splits):\n",
    "            node.is_leaf = True\n",
    "            return\n",
    "\n",
    "        feature, threshold = self.splits[split_idx]\n",
    "        axis = 'x' if feature == 'x' else 'y'\n",
    "        column = 0 if axis == 'x' else 1\n",
    "        goes_left = self.X[indices, column] <= threshold\n",
    "        left_indices, right_indices = indices[goes_left], indices[~goes_left]\n",
    "\n",
    "        # A split that leaves one side empty separates nothing: keep the node a leaf.\n",
    "        if len(left_indices) == 0 or len(right_indices) == 0:\n",
    "            return\n",
    "\n",
    "        node.is_leaf = False\n",
    "        node.feature = feature\n",
    "        node.threshold = threshold\n",
    "\n",
    "        lo, hi = node.bounds[axis]\n",
    "        node.left = TreeNode(depth=node.depth + 1, bounds={**node.bounds, axis: (lo, threshold)})\n",
    "        node.right = TreeNode(depth=node.depth + 1, bounds={**node.bounds, axis: (threshold, hi)})\n",
    "\n",
    "        self._build_node(node.left, left_indices, split_idx + 1)\n",
    "        self._build_node(node.right, right_indices, split_idx + 1)\n",
    "\n",
    "    def add_split(self, feature, threshold):\n",
    "        \"\"\"Append a split and rebuild the tree.\n",
    "\n",
    "        Args:\n",
    "            feature: 'x'/'width' or 'y'/'height' (case-insensitive).\n",
    "            threshold: split value; samples with feature <= threshold go left.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if ``feature`` is unknown or ``threshold`` is not a finite number.\n",
    "        \"\"\"\n",
    "        axis = self._normalize_feature(feature)\n",
    "        if axis is None:\n",
    "            raise ValueError(f\"Unknown feature {feature!r}; use 'x'/'width' or 'y'/'height'.\")\n",
    "        threshold = float(threshold)\n",
    "        if not np.isfinite(threshold):\n",
    "            raise ValueError(f\"Threshold must be a finite number, got {threshold!r}.\")\n",
    "        self.splits.append((axis, threshold))\n",
    "        self.build_tree_from_splits()\n",
    "\n",
    "    def remove_last_split(self):\n",
    "        \"\"\"Drop the most recent split, if any, and rebuild (or clear) the tree.\"\"\"\n",
    "        if self.splits:\n",
    "            self.splits.pop()\n",
    "            self.build_tree_from_splits()\n",
    "\n",
    "    def draw_tree(self, node=None, ax=None, x=0.5, y=1.0, dx=0.25, level=0, dy=0.15):\n",
    "        \"\"\"Recursively draw ``node`` and its subtree on ``ax``.\n",
    "\n",
    "        Positions are in ``ax`` data coordinates; the caller uses a [0, 1] x [0, 1] view.\n",
    "\n",
    "        Args:\n",
    "            node: subtree root; nothing is drawn when None.\n",
    "            ax: matplotlib Axes to draw on (defaults to the current Axes).\n",
    "            x, y: centre of ``node``'s box.\n",
    "            dx: horizontal offset of each child from ``node`` (halved at every level).\n",
    "            level: depth of ``node``; passed down the recursion, not used for layout.\n",
    "            dy: vertical distance between consecutive levels.\n",
    "        \"\"\"\n",
    "        if node is None:\n",
    "            return\n",
    "        if ax is None:\n",
    "            ax = plt.gca()\n",
    "\n",
    "        if node.is_leaf:\n",
    "            box_color = 'lightblue' if node.majority_class == 0 else 'orange'\n",
    "            alpha = 0.7\n",
    "            header = f'Leaf\\nClass: {node.majority_class}\\n'\n",
    "        else:\n",
    "            box_color = 'lightgreen'\n",
    "            alpha = 0.5\n",
    "            feature_name = 'Width' if node.feature == 'x' else 'Height'\n",
    "            header = f'{feature_name} ≤ {node.threshold:.2f}\\n'\n",
    "        text = (header\n",
    "                + f'Samples: {node.samples}\\n'\n",
    "                + f'Entropy: {node.entropy:.3f}\\n'\n",
    "                + f'Gini: {node.gini:.3f}')\n",
    "\n",
    "        bbox = dict(boxstyle='round,pad=0.3', facecolor=box_color,\n",
    "                    edgecolor='black', linewidth=2, alpha=alpha)\n",
    "        ax.text(x, y, text, ha='center', va='center', fontsize=8,\n",
    "                bbox=bbox, fontweight='bold')\n",
    "\n",
    "        if node.is_leaf or not (node.left and node.right):\n",
    "            return\n",
    "\n",
    "        y_child = y - dy\n",
    "        x_left, x_right = x - dx, x + dx\n",
    "        # Keep arrows between the boxes even when levels are packed closer together.\n",
    "        gap = min(0.05, dy / 3)\n",
    "        for x_child, color, branch in ((x_left, 'blue', 'Yes'), (x_right, 'red', 'No')):\n",
    "            ax.add_patch(FancyArrowPatch((x, y - gap), (x_child, y_child + gap),\n",
    "                                         arrowstyle='->', mutation_scale=20,\n",
    "                                         linewidth=2, color=color))\n",
    "            ax.text((x + x_child) / 2, (y + y_child) / 2, branch,\n",
    "                    fontsize=9, color=color, fontweight='bold')\n",
    "\n",
    "        self.draw_tree(node.left, ax, x_left, y_child, dx * 0.5, level + 1, dy)\n",
    "        self.draw_tree(node.right, ax, x_right, y_child, dx * 0.5, level + 1, dy)\n",
    "\n",
    "    def visualize(self, split_history):\n",
    "        \"\"\"Render the partition, tree and statistics panels as one PIL image.\n",
    "\n",
    "        Args:\n",
    "            split_history: text with one ``feature, threshold`` split per line, where\n",
    "                feature is x/width or y/height (case-insensitive). Lines that cannot\n",
    "                be parsed are listed in the statistics panel; None counts as empty.\n",
    "\n",
    "        Returns:\n",
    "            PIL.Image.Image of the three-panel figure.\n",
    "        \"\"\"\n",
    "        with self._lock:\n",
    "            self.splits, rejected_lines = self._parse_splits(split_history)\n",
    "            # Always rebuild: this also clears a stale tree when no splits remain.\n",
    "            self.build_tree_from_splits()\n",
    "\n",
    "            fig = plt.figure(figsize=(20, 10))\n",
    "            try:\n",
    "                gs = fig.add_gridspec(2, 2, height_ratios=[1, 1], width_ratios=[1.2, 1])\n",
    "                self._plot_feature_space(fig.add_subplot(gs[0, 0]))  # top-left\n",
    "                self._plot_tree(fig.add_subplot(gs[1, 0]))  # bottom-left\n",
    "                self._plot_statistics(fig.add_subplot(gs[:, 1]), rejected_lines)  # right\n",
    "                fig.tight_layout()\n",
    "                buf = io.BytesIO()\n",
    "                # Save and close this figure explicitly rather than pyplot's global\n",
    "                # current figure, which another thread could have switched.\n",
    "                fig.savefig(buf, format='png', dpi=120, bbox_inches='tight')\n",
    "            finally:\n",
    "                plt.close(fig)\n",
    "\n",
    "        buf.seek(0)\n",
    "        img = Image.open(buf)\n",
    "        img.load()  # decode now so the returned image no longer needs the buffer\n",
    "        return img\n",
    "\n",
    "    def _plot_feature_space(self, ax):\n",
    "        \"\"\"Scatter both classes and draw every split boundary on ``ax``.\"\"\"\n",
    "        for cls, color, name in ((0, 'blue', 'Class 0 (Lemon)'), (1, 'orange', 'Class 1 (Orange)')):\n",
    "            points = self.X[self.y == cls]\n",
    "            ax.scatter(points[:, 0], points[:, 1], c=color, label=name,\n",
    "                       s=100, alpha=0.6, edgecolors='k')\n",
    "\n",
    "        colors = plt.cm.rainbow(np.linspace(0, 1, len(self.splits)))\n",
    "        for idx, ((feature, threshold), color) in enumerate(zip(self.splits, colors), start=1):\n",
    "            draw_line = ax.axvline if feature == 'x' else ax.axhline\n",
    "            draw_line(threshold, color=color, linewidth=2.5, linestyle='--',\n",
    "                      label=f'Split {idx}: {feature}≤{threshold:.1f}', alpha=0.8)\n",
    "\n",
    "        ax.set_xlabel('Feature 1 (Width)', fontsize=14, fontweight='bold')\n",
    "        ax.set_ylabel('Feature 2 (Height)', fontsize=14, fontweight='bold')\n",
    "        ax.set_title('Partitioned Feature Space', fontsize=16, fontweight='bold')\n",
    "        ax.legend(fontsize=10, loc='upper left')\n",
    "        ax.grid(True, alpha=0.3)\n",
    "        # Widen the default [0, 10] view when needed so that no sample is clipped\n",
    "        # (the Gaussian clouds can place points outside it).\n",
    "        margin = 0.5\n",
    "        ax.set_xlim(min(0.0, self.X[:, 0].min() - margin), max(10.0, self.X[:, 0].max() + margin))\n",
    "        ax.set_ylim(min(0.0, self.X[:, 1].min() - margin), max(10.0, self.X[:, 1].max() + margin))\n",
    "\n",
    "    def _plot_tree(self, ax):\n",
    "        \"\"\"Draw the current tree on ``ax``, or a placeholder when there is none.\"\"\"\n",
    "        ax.set_xlim(0, 1)\n",
    "        ax.set_ylim(0, 1)\n",
    "        ax.axis('off')\n",
    "        ax.set_title('Decision Tree Structure', fontsize=16, fontweight='bold', pad=20)\n",
    "\n",
    "        if self.root is None:\n",
    "            ax.text(0.5, 0.5, 'No splits yet\\nAdd splits to build the tree',\n",
    "                    ha='center', va='center', fontsize=14,\n",
    "                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))\n",
    "            return\n",
    "\n",
    "        # Keep the default 0.15 level spacing, but compress deep trees so the\n",
    "        # deepest level still lands inside the [0, 1] view (the root is drawn at y=1).\n",
    "        depth = self._tree_depth(self.root)\n",
    "        dy = min(0.15, 0.9 / depth) if depth else 0.15\n",
    "        self.draw_tree(self.root, ax, dy=dy)\n",
    "\n",
    "    def _plot_statistics(self, ax, rejected_lines=()):\n",
    "        \"\"\"Write dataset, split, leaf and impurity statistics as text on ``ax``.\"\"\"\n",
    "        ax.axis('off')\n",
    "        entropy_initial = self.calculate_entropy(self.y)\n",
    "        gini_initial = self.calculate_gini(self.y)\n",
    "        rule = '-' * 50 + '\\n'\n",
    "\n",
    "        stats_text = 'DECISION TREE STATISTICS\\n' + '=' * 50 + '\\n\\n'\n",
    "        stats_text += f'Total Samples: {len(self.y)}\\n'\n",
    "        stats_text += f'  • Class 0: {int(np.sum(self.y == 0))}\\n'\n",
    "        stats_text += f'  • Class 1: {int(np.sum(self.y == 1))}\\n\\n'\n",
    "        stats_text += 'Initial Impurity:\\n'\n",
    "        stats_text += f'  • Entropy: {entropy_initial:.4f}\\n'\n",
    "        stats_text += f'  • Gini: {gini_initial:.4f}\\n\\n'\n",
    "\n",
    "        if rejected_lines:\n",
    "            stats_text += 'IGNORED LINES (expected: x|y, number):\\n' + rule\n",
    "            stats_text += ''.join(f'  • {line}\\n' for line in rejected_lines) + '\\n'\n",
    "\n",
    "        if not self.splits:\n",
    "            stats_text += 'No splits applied yet.\\n'\n",
    "            stats_text += '\\nAdd splits in the format:\\n'\n",
    "            stats_text += '  feature, threshold\\n\\n'\n",
    "            stats_text += 'Example:\\n'\n",
    "            stats_text += '  x, 5.0\\n'\n",
    "            stats_text += '  y, 6.5\\n'\n",
    "        else:\n",
    "            stats_text += f'Number of Splits: {len(self.splits)}\\n\\n'\n",
    "            stats_text += 'SPLIT SEQUENCE:\\n' + rule\n",
    "            for idx, (feature, threshold) in enumerate(self.splits, start=1):\n",
    "                feature_name = 'Width (x)' if feature == 'x' else 'Height (y)'\n",
    "                stats_text += f'\\n{idx}. {feature_name} ≤ {threshold:.2f}\\n'\n",
    "\n",
    "            leaves = []\n",
    "            self._collect_leaves(self.root, leaves)\n",
    "            if leaves:\n",
    "                stats_text += f'\\n\\nLEAF NODES: {len(leaves)}\\n' + rule\n",
    "                for idx, leaf in enumerate(leaves, start=1):\n",
    "                    stats_text += f'\\nLeaf {idx}:\\n'\n",
    "                    stats_text += f'  • Samples: {leaf.samples}\\n'\n",
    "                    stats_text += f'  • Class 0: {leaf.class_counts.get(0, 0)} | '\n",
    "                    stats_text += f'Class 1: {leaf.class_counts.get(1, 0)}\\n'\n",
    "                    stats_text += f'  • Prediction: Class {leaf.majority_class}\\n'\n",
    "                    stats_text += f'  • Entropy: {leaf.entropy:.4f}\\n'\n",
    "                    stats_text += f'  • Gini: {leaf.gini:.4f}\\n'\n",
    "\n",
    "                # Impurity of the whole tree: leaf impurities weighted by leaf size.\n",
    "                total_samples = sum(leaf.samples for leaf in leaves)\n",
    "                avg_entropy = sum(leaf.entropy * leaf.samples for leaf in leaves) / total_samples\n",
    "                avg_gini = sum(leaf.gini * leaf.samples for leaf in leaves) / total_samples\n",
    "\n",
    "                stats_text += '\\n\\nWEIGHTED AVERAGE IMPURITY:\\n' + rule\n",
    "                stats_text += f'  • Entropy: {avg_entropy:.4f}\\n'\n",
    "                stats_text += f'  • Gini: {avg_gini:.4f}\\n'\n",
    "                stats_text += '\\nTOTAL INFORMATION GAIN:\\n'\n",
    "                stats_text += f'  • {entropy_initial - avg_entropy:.4f}\\n'\n",
    "                stats_text += '\\nTOTAL GINI REDUCTION:\\n'\n",
    "                stats_text += f'  • {gini_initial - avg_gini:.4f}\\n'\n",
    "\n",
    "        ax.text(0.05, 0.95, stats_text, transform=ax.transAxes,\n",
    "                fontsize=10, verticalalignment='top',\n",
    "                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),\n",
    "                family='monospace')\n",
    "\n",
    "    def _tree_depth(self, node):\n",
    "        \"\"\"Return the number of split levels below ``node`` (0 for a leaf or None).\"\"\"\n",
    "        if node is None or node.is_leaf:\n",
    "            return 0\n",
    "        return 1 + max(self._tree_depth(node.left), self._tree_depth(node.right))\n",
    "\n",
    "    def _collect_leaves(self, node, leaves):\n",
    "        \"\"\"Append the leaves under ``node`` to ``leaves`` in left-to-right order.\"\"\"\n",
    "        if node is None:\n",
    "            return\n",
    "        if node.is_leaf:\n",
    "            leaves.append(node)\n",
    "        else:\n",
    "            self._collect_leaves(node.left, leaves)\n",
    "            self._collect_leaves(node.right, leaves)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c1e9f30",
   "metadata": {},
   "source": [
    "## Gradio app\n",
    "\n",
    "One shared `partitioner` serves every browser session; `visualize` redraws the figure when the page loads and whenever **Update Visualization** is clicked."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d3b6a15",
   "metadata": {},
   "outputs": [],
   "source": [
    "partitioner = DecisionTreePartitioner()\n",
    "\n",
    "with gr.Blocks(title='Multi-Split Decision Tree Visualizer', theme=gr.themes.Soft()) as demo:\n",
    "    gr.Markdown(\n",
    "        '# 🌳 Interactive Multi-Split Decision Tree Visualizer\\n\\n'\n",
    "        'Build a decision tree step-by-step and visualize the partitioning process!'\n",
    "    )\n",
    "\n",
    "    with gr.Row():\n",
    "        with gr.Column(scale=1):\n",
    "            split_input = gr.Textbox(\n",
    "                label='📝 Split Sequence (one per line: feature, threshold)',\n",
    "                placeholder='x, 5.0\\ny, 6.5\\nx, 3.0',\n",
    "                lines=10,\n",
    "                value='x, 5.0',\n",
    "            )\n",
    "            update_btn = gr.Button('🔄 Update Visualization', variant='primary', size='lg')\n",
    "            gr.Markdown(\n",
    "                '### Example Splits:\\n'\n",
    "                '`feature` is `x` (width) or `y` (height); split *k* is applied to '\n",
    "                'every node at depth *k − 1*.\\n\\n'\n",
    "                '**Simple 2-split tree:**\\n'\n",
    "                '```\\nx, 5.0\\ny, 6.5\\n```\\n\\n'\n",
    "                '**Complex 4-split tree:**\\n'\n",
    "                '```\\nx, 5.0\\ny, 6.5\\nx, 3.0\\ny, 8.0\\n```'\n",
    "            )\n",
    "\n",
    "        with gr.Column(scale=2):\n",
    "            output_image = gr.Image(label='Visualization', height=800)\n",
    "\n",
    "    # Redraw when the button is clicked and once when the page first loads.\n",
    "    update_btn.click(fn=partitioner.visualize, inputs=[split_input], outputs=output_image)\n",
    "    demo.load(fn=partitioner.visualize, inputs=[split_input], outputs=output_image)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4f0c8d2",
   "metadata": {},
   "source": [
    "## Launch\n",
    "\n",
    "`share=True` additionally publishes a temporary public `*.gradio.live` URL that anyone with the link can open; set `SHARE_PUBLIC_LINK = False` to serve the app on localhost only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2a84c67",
   "metadata": {},
   "outputs": [],
   "source": [
    "SHARE_PUBLIC_LINK = True  # False = serve on localhost only\n",
    "\n",
    "demo.launch(share=SHARE_PUBLIC_LINK)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "WORK",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}