# Multi-Split Decision Tree Visualizer

This notebook creates an interactive Gradio app to visualize how decision trees partition the feature space with **multiple splits** and shows the complete **decision tree structure**.

## ✨ New Features:
- **Multiple Partitions**: Add as many splits as you want to build a complete tree
- **Decision Tree Visualization**: See the tree structure with all nodes and connections
- **Interactive Split Entry**: Add splits in a simple text format (feature, threshold)
- **Comprehensive Statistics**: Track entropy and Gini index for each node and leaf
- **Color-coded Visualization**: 
  - Blue arrows = "Yes" branch (≤ threshold)
  - Red arrows = "No" branch (> threshold)
  - Light blue leaves = Predicts Class 0 (Lemon)
  - Orange leaves = Predicts Class 1 (Orange)

## 📊 Three-Panel Display:
1. **Top-Left**: Partitioned feature space with all split boundaries
2. **Bottom-Left**: Complete decision tree structure
3. **Right**: Detailed statistics and impurity measures

In [1]:
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, FancyBboxPatch
import io
from PIL import Image
from matplotlib.patches import FancyArrowPatch

class TreeNode:
    """One node of the decision tree.

    A freshly created node is a leaf with no statistics; the tree builder
    later fills in the split, the children and the per-node sample summary.
    """

    def __init__(self, depth=0, bounds=None):
        # Rectangle of feature space covered by this node. A missing (or
        # empty) ``bounds`` falls back to the full 10 x 10 area.
        self.bounds = bounds or {'x': (0, 10), 'y': (0, 10)}
        self.depth = depth

        # Split applied at this node: feature is 'x' or 'y'.
        self.feature = self.threshold = None
        self.left = self.right = None
        self.is_leaf = True

        # Summary of the samples that reach this node.
        self.samples = self.class_counts = None
        self.entropy = self.gini = None
        self.majority_class = None
        
class DecisionTreePartitioner:
    def __init__(self):
        self.reset_data()
        self.splits = []  # List of (feature, threshold) tuples
        self.root = None
        
    def reset_data(self):
        """Generate sample data with two classes"""
        np.random.seed(42)
        # Class 0 (blue) - bottom left
        n_samples = 50
        self.X0 = np.random.randn(n_samples, 2) * 1.5 + np.array([3, 3])
        # Class 1 (red) - top right  
        self.X1 = np.random.randn(n_samples, 2) * 1.5 + np.array([7, 7])
        
        self.X = np.vstack([self.X0, self.X1])
        self.y = np.hstack([np.zeros(n_samples), np.ones(n_samples)])
        self.splits = []
        self.root = None
        
    def calculate_entropy(self, y):
        """Calculate entropy for a set of labels"""
        if len(y) == 0:
            return 0
        _, counts = np.unique(y, return_counts=True)
        probabilities = counts / len(y)
        entropy = -np.sum(probabilities * np.log2(probabilities + 1e-10))
        return entropy
    
    def calculate_gini(self, y):
        """Calculate Gini index for a set of labels"""
        if len(y) == 0:
            return 0
        _, counts = np.unique(y, return_counts=True)
        probabilities = counts / len(y)
        gini = 1 - np.sum(probabilities ** 2)
        return gini
    
    def build_tree_from_splits(self):
        """Build tree structure from the list of splits"""
        if not self.splits:
            return None
            
        self.root = TreeNode(depth=0)
        self._build_node(self.root, np.arange(len(self.y)), 0)
        return self.root
    
    def _build_node(self, node, indices, split_idx):
        """Recursively build tree nodes"""
        if len(indices) == 0:
            return
            
        # Calculate node statistics
        node.samples = len(indices)
        y_node = self.y[indices]
        unique, counts = np.unique(y_node, return_counts=True)
        node.class_counts = dict(zip(unique.astype(int), counts))
        node.entropy = self.calculate_entropy(y_node)
        node.gini = self.calculate_gini(y_node)
        node.majority_class = int(unique[np.argmax(counts)])
        
        # Check if we have more splits to apply
        if split_idx >= len(self.splits):
            node.is_leaf = True
            return
            
        # Apply the split
        feature, threshold = self.splits[split_idx]
        feature_idx = 0 if feature == 'x' else 1
        
        X_node = self.X[indices]
        left_mask = X_node[:, feature_idx] <= threshold
        right_mask = ~left_mask
        
        left_indices = indices[left_mask]
        right_indices = indices[right_mask]
        
        # Only create split if both children are non-empty
        if len(left_indices) > 0 and len(right_indices) > 0:
            node.is_leaf = False
            node.feature = feature
            node.threshold = threshold
            
            # Create child nodes with updated bounds
            left_bounds = node.bounds.copy()
            right_bounds = node.bounds.copy()
            
            if feature == 'x':
                left_bounds['x'] = (node.bounds['x'][0], threshold)
                right_bounds['x'] = (threshold, node.bounds['x'][1])
            else:
                left_bounds['y'] = (node.bounds['y'][0], threshold)
                right_bounds['y'] = (threshold, node.bounds['y'][1])
            
            node.left = TreeNode(depth=node.depth + 1, bounds=left_bounds)
            node.right = TreeNode(depth=node.depth + 1, bounds=right_bounds)
            
            # Recursively build children
            self._build_node(node.left, left_indices, split_idx + 1)
            self._build_node(node.right, right_indices, split_idx + 1)
    
    def add_split(self, feature, threshold):
        """Add a new split to the tree"""
        self.splits.append((feature, threshold))
        self.build_tree_from_splits()
        
    def remove_last_split(self):
        """Remove the last split"""
        if self.splits:
            self.splits.pop()
            if self.splits:
                self.build_tree_from_splits()
            else:
                self.root = None
    
    def draw_tree(self, node=None, ax=None, x=0.5, y=1.0, dx=0.25, level=0):
        """Recursively draw the decision tree"""
        if node is None:
            return
            
        # Node styling
        if node.is_leaf:
            box_color = 'lightblue' if node.majority_class == 0 else 'orange'
            alpha = 0.7
        else:
            box_color = 'lightgreen'
            alpha = 0.5
        
        # Create node text
        if node.is_leaf:
            text = f"Leaf\nClass: {node.majority_class}\n"
            text += f"Samples: {node.samples}\n"
            text += f"Entropy: {node.entropy:.3f}\n"
            text += f"Gini: {node.gini:.3f}"
        else:
            feature_name = "Width" if node.feature == 'x' else "Height"
            text = f"{feature_name} ‚â§ {node.threshold:.2f}\n"
            text += f"Samples: {node.samples}\n"
            text += f"Entropy: {node.entropy:.3f}\n"
            text += f"Gini: {node.gini:.3f}"
        
        # Draw box
        bbox = dict(boxstyle="round,pad=0.3", facecolor=box_color, 
                   edgecolor='black', linewidth=2, alpha=alpha)
        ax.text(x, y, text, ha='center', va='center', fontsize=8,
               bbox=bbox, fontweight='bold')
        
        # Draw connections to children
        if not node.is_leaf and node.left and node.right:
            # Left child
            y_child = y - 0.15
            x_left = x - dx
            x_right = x + dx
            
            # Draw arrows
            arrow_left = FancyArrowPatch((x, y - 0.05), (x_left, y_child + 0.05),
                                        arrowstyle='->', mutation_scale=20, 
                                        linewidth=2, color='blue')
            arrow_right = FancyArrowPatch((x, y - 0.05), (x_right, y_child + 0.05),
                                         arrowstyle='->', mutation_scale=20,
                                         linewidth=2, color='red')
            ax.add_patch(arrow_left)
            ax.add_patch(arrow_right)
            
            # Add Yes/No labels
            ax.text((x + x_left) / 2, (y + y_child) / 2, 'Yes', 
                   fontsize=9, color='blue', fontweight='bold')
            ax.text((x + x_right) / 2, (y + y_child) / 2, 'No',
                   fontsize=9, color='red', fontweight='bold')
            
            # Recursively draw children
            self.draw_tree(node.left, ax, x_left, y_child, dx * 0.5, level + 1)
            self.draw_tree(node.right, ax, x_right, y_child, dx * 0.5, level + 1)
    
    def visualize(self, split_history):
        """Render the partition plot, the tree diagram and the statistics.

        The split list and the tree stored on this instance are replaced by
        what ``split_history`` describes, so clearing the text also clears
        the tree.

        Args:
            split_history: Text with one split per line in the form
                ``feature, threshold`` where feature is ``x`` or ``y``.
                Lines that do not match are skipped.

        Returns:
            The rendered three-panel figure as a PIL image.
        """
        self.splits = self._parse_split_history(split_history)
        # Start from no tree so an earlier call's tree cannot be drawn when
        # the current input contains no valid split.
        self.root = None
        if self.splits:
            self.build_tree_from_splits()

        fig = plt.figure(figsize=(20, 10))
        try:
            gs = fig.add_gridspec(2, 2, height_ratios=[1, 1],
                                  width_ratios=[1.2, 1])
            ax1 = fig.add_subplot(gs[0, 0])  # Partition view
            ax2 = fig.add_subplot(gs[1, 0])  # Decision tree
            ax3 = fig.add_subplot(gs[:, 1])  # Statistics

            self._plot_partitions(ax1)
            self._plot_tree(ax2)

            ax3.axis('off')
            ax3.text(0.05, 0.95, self._build_statistics_text(),
                     transform=ax3.transAxes,
                     fontsize=10, verticalalignment='top',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),
                     family='monospace')

            fig.tight_layout()

            # Save through the figure object rather than pyplot's "current
            # figure" so the right figure is written whatever is current.
            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=120, bbox_inches='tight')
        finally:
            # Always release the figure, even when rendering fails.
            plt.close(fig)

        buf.seek(0)
        return Image.open(buf)

    def _parse_split_history(self, split_history):
        """Parse ``feature, threshold`` lines into a list of valid splits."""
        splits = []
        for line in (split_history or "").strip().splitlines():
            parts = line.split(',')
            if len(parts) != 2:
                continue
            feature = parts[0].strip().lower()
            # The tree builder treats anything other than 'x' as 'y', so
            # unknown names are dropped instead of becoming y-splits.
            if feature not in ('x', 'y'):
                continue
            try:
                threshold = float(parts[1].strip())
            except ValueError:
                continue
            # float() accepts "nan"/"inf", which cannot be drawn or split on.
            if np.isfinite(threshold):
                splits.append((feature, threshold))
        return splits

    def _plot_partitions(self, ax):
        """Draw the samples and one dashed boundary line per split on ``ax``."""
        ax.scatter(self.X[self.y == 0, 0], self.X[self.y == 0, 1],
                   c='blue', label='Class 0 (Lemon)', s=100, alpha=0.6,
                   edgecolors='k')
        ax.scatter(self.X[self.y == 1, 0], self.X[self.y == 1, 1],
                   c='orange', label='Class 1 (Orange)', s=100, alpha=0.6,
                   edgecolors='k')

        # One rainbow colour per split, in the order the splits were given.
        colors = plt.cm.rainbow(np.linspace(0, 1, len(self.splits)))
        for idx, (feature, threshold) in enumerate(self.splits):
            if feature == 'x':
                ax.axvline(x=threshold, color=colors[idx], linewidth=2.5,
                           linestyle='--', alpha=0.8,
                           label=f'Split {idx+1}: x≤{threshold:.1f}')
            else:
                ax.axhline(y=threshold, color=colors[idx], linewidth=2.5,
                           linestyle='--', alpha=0.8,
                           label=f'Split {idx+1}: y≤{threshold:.1f}')

        ax.set_xlabel('Feature 1 (Width)', fontsize=14, fontweight='bold')
        ax.set_ylabel('Feature 2 (Height)', fontsize=14, fontweight='bold')
        ax.set_title('Partitioned Feature Space', fontsize=16,
                     fontweight='bold')
        ax.legend(fontsize=10, loc='upper left')
        ax.grid(True, alpha=0.3)
        ax.set_xlim(0, 10)
        ax.set_ylim(0, 10)

    def _plot_tree(self, ax):
        """Draw the current tree on ``ax``, or a hint when there is none."""
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')
        ax.set_title('Decision Tree Structure', fontsize=16,
                     fontweight='bold', pad=20)

        if self.root:
            self.draw_tree(self.root, ax)
        else:
            ax.text(0.5, 0.5, 'No splits yet\nAdd splits to build the tree',
                    ha='center', va='center', fontsize=14,
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    def _build_statistics_text(self):
        """Return the report shown in the statistics panel."""
        entropy_initial = self.calculate_entropy(self.y)
        gini_initial = self.calculate_gini(self.y)
        rule = "-" * 50

        parts = [
            "DECISION TREE STATISTICS\n" + "=" * 50 + "\n\n",
            f"Total Samples: {len(self.y)}\n",
            f"  • Class 0: {np.sum(self.y == 0)}\n",
            f"  • Class 1: {np.sum(self.y == 1)}\n\n",
            "Initial Impurity:\n",
            f"  • Entropy: {entropy_initial:.4f}\n",
            f"  • Gini: {gini_initial:.4f}\n\n",
        ]

        if not self.splits:
            parts += [
                "No splits applied yet.\n",
                "\nAdd splits in the format:\n",
                "  feature, threshold\n\n",
                "Example:\n",
                "  x, 5.0\n",
                "  y, 6.5\n",
            ]
            return "".join(parts)

        parts.append(f"Number of Splits: {len(self.splits)}\n\n")
        parts.append("SPLIT SEQUENCE:\n" + rule + "\n")
        for number, (feature, threshold) in enumerate(self.splits, start=1):
            feature_name = "Width (x)" if feature == 'x' else "Height (y)"
            parts.append(f"\n{number}. {feature_name} ≤ {threshold:.2f}\n")

        leaves = []
        self._collect_leaves(self.root, leaves)
        if not leaves:
            return "".join(parts)

        parts.append(f"\n\nLEAF NODES: {len(leaves)}\n" + rule + "\n")
        for number, leaf in enumerate(leaves, start=1):
            parts += [
                f"\nLeaf {number}:\n",
                f"  • Samples: {leaf.samples}\n",
                f"  • Class 0: {leaf.class_counts.get(0, 0)} | ",
                f"Class 1: {leaf.class_counts.get(1, 0)}\n",
                f"  • Prediction: Class {leaf.majority_class}\n",
                f"  • Entropy: {leaf.entropy:.4f}\n",
                f"  • Gini: {leaf.gini:.4f}\n",
            ]

        # Impurity of the whole partition: leaf impurities weighted by the
        # number of samples in each leaf.
        total_samples = sum(leaf.samples for leaf in leaves)
        avg_entropy = sum(leaf.entropy * leaf.samples
                          for leaf in leaves) / total_samples
        avg_gini = sum(leaf.gini * leaf.samples
                       for leaf in leaves) / total_samples

        parts += [
            "\n\nWEIGHTED AVERAGE IMPURITY:\n" + rule + "\n",
            f"  • Entropy: {avg_entropy:.4f}\n",
            f"  • Gini: {avg_gini:.4f}\n",
            "\nTOTAL INFORMATION GAIN:\n",
            f"  • {entropy_initial - avg_entropy:.4f}\n",
            "\nTOTAL GINI REDUCTION:\n",
            f"  • {gini_initial - avg_gini:.4f}\n",
        ]
        return "".join(parts)
    
    def _collect_leaves(self, node, leaves):
        """Collect all leaf nodes"""
        if node is None:
            return
        if node.is_leaf:
            leaves.append(node)
        else:
            self._collect_leaves(node.left, leaves)
            self._collect_leaves(node.right, leaves)

# Create the partitioner
# NOTE: this single instance is shared by every callback below, and
# visualize() overwrites its split list and tree on each call.
partitioner = DecisionTreePartitioner()

# Create Gradio interface
with gr.Blocks(title="Multi-Split Decision Tree Visualizer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🌳 Interactive Multi-Split Decision Tree Visualizer
    
    Build a decision tree step-by-step and visualize the partitioning process!
    
    """)
    
    with gr.Row():
        with gr.Column(scale=1):
            split_input = gr.Textbox(
                label="📝 Split Sequence (one per line: feature, threshold)",
                placeholder="x, 5.0\ny, 6.5\nx, 3.0",
                lines=10,
                value="x, 5.0"
            )
            
            update_btn = gr.Button("🔄 Update Visualization", variant="primary", size="lg")
            
            gr.Markdown("""
            ### Example Splits:
            **Simple 2-split tree:**
            ```
            x, 5.0
            y, 6.5
            ```
            
            **Complex 4-split tree:**
            ```
            x, 5.0
            y, 6.5
            x, 3.0
            y, 8.0
            ```
            """)
            
        with gr.Column(scale=2):
            output_image = gr.Image(label="Visualization", height=800)
    
    # The button click and the initial page load run the same callback with
    # the same wiring, so the arguments are defined once.
    render_kwargs = dict(
        fn=partitioner.visualize,
        inputs=[split_input],
        outputs=output_image
    )

    # Update visualization
    update_btn.click(**render_kwargs)
    
    # Initial visualization
    demo.load(**render_kwargs)

# Launch the app
# share=True additionally requests a public gradio.live URL for the app.
demo.launch(share=True)

  from .autonotebook import tqdm as notebook_tqdm


* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://4d58db9d9d6f8c53bc.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
* Running on public URL: https://4d58db9d9d6f8c53bc.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


