wjnwjn59 committed
Commit 4dec27c · 1 Parent(s): e08f4c0

modify illustration logic
.gitignore ADDED
@@ -0,0 +1,4 @@
+ __pycache__/
+ __MACOSX/
+
+ .DS_Store
README.md CHANGED
@@ -12,7 +12,7 @@ license: "mit"
 
 # AIO2025 Module 03 - LightGBM Demo
 
- This interactive demo showcases LightGBM (Light Gradient Boosting Machine) algorithms for both classification and regression tasks. The application provides a comprehensive interface for exploring efficient gradient boosting with leaf-wise tree growth where trees are trained sequentially to minimize gradient errors through dynamic visualizations and real-time parameter adjustment.
 
 ## ⚡ Features
 
@@ -26,13 +26,14 @@ This interactive demo showcases LightGBM (Light Gradient Boosting Machine) algor
 ### LightGBM Parameters
 - **Number of Trees**: Control gradient boosting iterations (limited to 1000 for performance)
 - **Learning Rate**: Step size shrinkage for gradient descent (0.001-1.0)
- - **Max Depth**: Individual tree depth (default: 6, leaf-wise growth)
 - **Early Stopping**: Automatic stopping when validation loss stops improving
 
 ### Visualizations
- - **Training Progress Chart**: Shows how loss evolves with early stopping during gradient boosting
- - **Individual Tree Visualization**: Detailed view of selected tree structure with leaf-wise growth
- - **Feature Importance**: Displays which features matter most using gradient-based importance
 - **LightGBM Process**: Gradient boosting aggregation display showing how predictions build up efficiently
 
 ## ⚡ Quick Start
@@ -61,7 +62,10 @@ This interactive demo showcases LightGBM (Light Gradient Boosting Machine) algor
 - `scikit-learn`: Data preprocessing utilities
 - `pandas`: Data manipulation
 - `numpy`: Numerical operations
- - `plotly`: Interactive visualizations
 - `gradio`: Web interface
 
 ### Architecture
@@ -75,6 +79,7 @@ This interactive demo showcases LightGBM (Light Gradient Boosting Machine) algor
 ### LightGBM Benefits
 - **Gradient Boosting**: Trees trained sequentially to minimize loss gradients
 - **High Performance**: Fast training and prediction with leaf-wise tree growth
 - **Feature Importance**: Robust importance scores through gradient-based methods
 - **Memory Efficiency**: Uses gradient-based one-side sampling (GOSS) and exclusive feature bundling (EFB)
 - **Early Stopping**: Automatic stopping when validation loss stops improving
@@ -101,7 +106,8 @@ This interactive demo showcases LightGBM (Light Gradient Boosting Machine) algor
 
 - **Number of Trees**: Limited to 1000 for optimal performance in this demo
 - **Learning Rate**: Default 0.1 works well; lower rates (0.01-0.05) create more conservative models, higher rates (0.2-0.3) for faster convergence
- - **Max Depth**: Default depth 6 balances performance and overfitting; deeper trees (8-12) for complex patterns
 - **Early Stopping**: Built-in early stopping prevents overfitting automatically
 
 ## 🎯 Use Cases
@@ -125,7 +131,8 @@ This interactive demo showcases LightGBM (Light Gradient Boosting Machine) algor
 - **Memory Efficient**: Optimized for gradient boosting with GOSS and EFB
 - **Real-time Updates**: Instant parameter adjustment and visualization
 - **Tree Selection**: Interactive dropdown to explore individual gradient boosting trees (up to 100)
- - **Gradient Nature**: Each tree fits gradients of loss function from previous iterations
 
 ## 🔗 Related Resources
 
 
 # AIO2025 Module 03 - LightGBM Demo
 
+ This interactive demo showcases LightGBM (Light Gradient Boosting Machine) for both classification and regression tasks. The application provides a comprehensive interface for exploring efficient gradient boosting, where trees are trained sequentially to minimize gradient errors, through dynamic visualizations and real-time parameter adjustment. LightGBM grows trees leaf-wise rather than depth-wise for faster convergence and better performance.
 
 ## ⚡ Features
 
 ### LightGBM Parameters
 - **Number of Trees**: Control gradient boosting iterations (limited to 1000 for performance)
 - **Learning Rate**: Step size shrinkage for gradient descent (0.001-1.0)
+ - **Number of Leaves**: Maximum number of leaves in one tree (default: 31; controls complexity)
+ - **Min Data in Leaf**: Minimum number of data points in one leaf (default: 20; prevents overfitting)
 - **Early Stopping**: Automatic stopping when validation loss stops improving
 
 ### Visualizations
+ - **Training Progress Chart**: Interactive Plotly chart showing how loss evolves with early stopping during gradient boosting
+ - **Feature Importance**: Interactive Plotly bar chart displaying which features matter most, using gradient-based importance
+ - **Individual Tree Visualization**: Detailed matplotlib view of the selected tree's leaf-wise structure
 - **LightGBM Process**: Gradient boosting aggregation display showing how predictions build up efficiently
 
 ## ⚡ Quick Start
 
 - `scikit-learn`: Data preprocessing utilities
 - `pandas`: Data manipulation
 - `numpy`: Numerical operations
+ - `plotly`: Interactive visualizations for charts
+ - `matplotlib`: Static visualizations for tree plots
+ - `graphviz`: Tree structure visualization
+ - `Pillow`: Image processing
 - `gradio`: Web interface
 
 ### Architecture
 
 ### LightGBM Benefits
 - **Gradient Boosting**: Trees trained sequentially to minimize loss gradients
 - **High Performance**: Fast training and prediction with leaf-wise tree growth
+ - **Leaf-wise Growth**: Grows trees leaf-by-leaf instead of level-by-level for faster convergence
 - **Feature Importance**: Robust importance scores through gradient-based methods
 - **Memory Efficiency**: Uses gradient-based one-side sampling (GOSS) and exclusive feature bundling (EFB)
 - **Early Stopping**: Automatic stopping when validation loss stops improving
 
 - **Number of Trees**: Limited to 1000 for optimal performance in this demo
 - **Learning Rate**: Default 0.1 works well; lower rates (0.01-0.05) create more conservative models, higher rates (0.2-0.3) for faster convergence
+ - **Number of Leaves**: Default 31 works well; for a depth-7 equivalent, use ~70-80 leaves instead of 127 to prevent overfitting
+ - **Min Data in Leaf**: Default 20 prevents overfitting; increase to hundreds or thousands for large datasets
 - **Early Stopping**: Built-in early stopping prevents overfitting automatically
 
 ## 🎯 Use Cases
 
 - **Memory Efficient**: Optimized for gradient boosting with GOSS and EFB
 - **Real-time Updates**: Instant parameter adjustment and visualization
 - **Tree Selection**: Interactive dropdown to explore individual gradient boosting trees (up to 100)
+ - **Leaf-wise Growth**: LightGBM uses leaf-wise tree growth for faster convergence compared to depth-wise growth
+ - **Parameter Tuning**: num_leaves is the main parameter to control tree complexity; min_data_in_leaf prevents overfitting
 
 ## 🔗 Related Resources
 
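The num_leaves guidance above comes from the depth-to-leaves formula the replaced code used (num_leaves = 2**max_depth - 1, so a depth-7 tree maps to 127 leaves; see the src/lightgbm_core.py diff below). A minimal self-contained sketch, assuming only the packages pinned in requirements.txt, of how the two new parameters reach lgb.train:

import lightgbm as lgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# A depth-7 equivalent would be num_leaves near 2**7 - 1 = 127;
# the README recommends ~70-80 instead to curb overfitting.
X, y = load_iris(return_X_y=True)
X_tr, X_val, y_tr, y_val = train_test_split(X, y, train_size=0.8, random_state=0)

params = {
    "objective": "multiclass",
    "num_class": 3,
    "num_leaves": 31,        # main complexity knob (README default)
    "min_data_in_leaf": 20,  # minimum samples per leaf (README default)
    "learning_rate": 0.1,
    "verbose": -1,
}
train_set = lgb.Dataset(X_tr, label=y_tr)
val_set = lgb.Dataset(X_val, label=y_val, reference=train_set)
booster = lgb.train(params, train_set, num_boost_round=100,
                    valid_sets=[val_set],
                    callbacks=[lgb.early_stopping(20, verbose=False)])
print(booster.best_iteration, booster.num_trees())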
__pycache__/app.cpython-312.pyc CHANGED
Binary files a/__pycache__/app.cpython-312.pyc and b/__pycache__/app.cpython-312.pyc differ
 
app.py CHANGED
@@ -28,6 +28,8 @@ vlai_template.configure(
 )
 
 current_dataframe = None
 
 def load_sample_data_fallback(dataset_choice="Iris"):
 """Fallback data loading function when LightGBM is not available"""
@@ -301,8 +303,8 @@ def update_configuration(df_preview, target_col):
 # AdaBoost-specific functions
 
 
- def execute_prediction(df_preview, target_col, n_estimators, max_depth, learning_rate, train_test_split_ratio, show_split_info, *input_values):
- global current_dataframe
 df = current_dataframe
 
 EMPTY_PLOT = None
@@ -321,6 +323,10 @@ def execute_prediction(df_preview, target_col, n_estimators, max_depth, learning
 is_valid, validation_msg, problem_type = validate_config(df, target_col)
 if not is_valid:
 return (EMPTY_PLOT, EMPTY_PLOT, EMPTY_PLOT, error_style.format("Configuration issue."), default_dropdown)
 
 try:
 if LIGHTGBM_AVAILABLE:
@@ -338,12 +344,12 @@ def execute_prediction(df_preview, target_col, n_estimators, max_depth, learning
 new_point_dict[comp["name"]] = v
 
 boosting_progress_fig, loss_chart_fig, importance_fig, prediction, pred_details, summary, aggregation_display = lightgbm_core.run_lightgbm_and_visualize(
- df, target_col, new_point_dict, n_estimators, max_depth, learning_rate, train_test_split_ratio, problem_type
 )
 
 feature_cols = [c for c in df.columns if c != target_col]
 first_tree_fig = lightgbm_core.get_individual_tree_visualization(
- lightgbm_core._get_current_model(), 0, feature_cols, problem_type
 )
 
 updated_tree_selector = update_tree_selector_choices(n_estimators)
@@ -356,30 +362,67 @@ def execute_prediction(df_preview, target_col, n_estimators, max_depth, learning
 
 
 def update_tree_selector_choices(n_estimators):
- # Limit tree visualization dropdown to 50 trees for UI performance
- n_estimators_limited = min(int(n_estimators), 50)
- choices = [f"Tree {i+1}" for i in range(n_estimators_limited)]
 return gr.Dropdown(choices=choices, value="Tree 1")
 
 
- def update_tree_visualization(tree_selector):
- global current_dataframe
 
 if current_dataframe is None or current_dataframe.empty:
 return None
 
 try:
 model = lightgbm_core._get_current_model()
 if model is None:
 return None
 
 tree_index = int(tree_selector.split()[-1]) - 1
- _, _, problem_type = validate_config(current_dataframe, current_dataframe.columns[-1])
- feature_cols = [c for c in current_dataframe.columns if c != current_dataframe.columns[-1]]
- tree_fig = lightgbm_core.get_individual_tree_visualization(model, tree_index, feature_cols, problem_type)
 
 return tree_fig
 except Exception as e:
 return None
 
 
@@ -416,7 +459,7 @@ with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=T
 n_estimators = gr.Number(
 label="Number of Trees",
 value=100, minimum=1, maximum=1000, precision=0,
- info="Number of gradient boosting trees (up to 1000)"
 )
 learning_rate = gr.Slider(
 label="Learning Rate",
@@ -424,10 +467,15 @@ with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=T
 info="Step size shrinkage for each tree"
 )
 with gr.Row():
- max_depth = gr.Number(
- label="Max Depth",
- value=6, minimum=1, maximum=15, precision=0,
- info="Maximum depth of individual trees (-1 for unlimited, but 6 is typical)"
 )
 
 gr.Markdown("**📊 Data Split Configuration**")
@@ -442,6 +490,18 @@ with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=T
 value=True,
 info="Display train/validation set information"
 )
 
 inputs_group = gr.Group(visible=False)
 with inputs_group:
@@ -476,15 +536,17 @@ with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=T
 feature_importance_plot = gr.Plot(label="Feature Importance", visible=True)
 aggregation_display = gr.HTML("**⚡ LightGBM Process**<br><br>LightGBM details will appear here showing how the prediction builds up.", label="⚡ LightGBM Process")
 
- gr.Markdown("""⚡ **LightGBM Tips**:
 - **📉 Loss Evolution Chart**: Monitor training and validation loss to understand model convergence with early stopping.
 - **🌳 Individual Tree Visualization**: Select any tree to see its leaf-wise structure and contribution.
 - **📊 Feature Importance**: Displays which features are most influential using gradient-based importance.
 - **🎯 Parameter Tuning**: Try different **number of trees** (up to 1000) and **learning rate** (0.001-1.0).
 - **⚡ Learning Rate**: Default 0.1 works well; lower values (0.01-0.05) for more conservative models, higher values (0.2-0.3) for faster convergence.
- - **🌲 Tree Depth**: Default depth 6 balances complexity and performance; deeper trees (8-12) for complex patterns.
- - **🎯 Gradient Boosting**: LightGBM uses gradient-based one-side sampling and exclusive feature bundling for efficiency.
 - **🔍 Tree Analysis**: Use the tree selector to understand how each tree contributes to gradient boosting ensemble.
 """)
 
 vlai_template.create_footer()
@@ -515,13 +577,13 @@ with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=T
 
 run_prediction_btn.click(
 fn=execute_prediction,
- inputs=[data_preview, target_column, n_estimators, max_depth, learning_rate, train_test_split_ratio, show_split_info] + input_components,
 outputs=[loss_chart, individual_tree_plot, feature_importance_plot, aggregation_display, tree_selector],
 )
 
 tree_selector.change(
 fn=update_tree_visualization,
- inputs=[tree_selector],
 outputs=[individual_tree_plot],
 )
 
 )
 
 current_dataframe = None
+ current_target_column = None
+ current_problem_type = None
 
 def load_sample_data_fallback(dataset_choice="Iris"):
 """Fallback data loading function when LightGBM is not available"""
 
 # AdaBoost-specific functions
 
 
+ def execute_prediction(df_preview, target_col, n_estimators, num_leaves, min_data_in_leaf, learning_rate, train_test_split_ratio, show_split_info, use_early_stopping, early_stopping_rounds, *input_values):
+ global current_dataframe, current_target_column, current_problem_type
 df = current_dataframe
 
 EMPTY_PLOT = None
 
 is_valid, validation_msg, problem_type = validate_config(df, target_col)
 if not is_valid:
 return (EMPTY_PLOT, EMPTY_PLOT, EMPTY_PLOT, error_style.format("Configuration issue."), default_dropdown)
+
+ # Store the current target column and problem type globally
+ current_target_column = target_col
+ current_problem_type = problem_type
 
 try:
 if LIGHTGBM_AVAILABLE:
 
 new_point_dict[comp["name"]] = v
 
 boosting_progress_fig, loss_chart_fig, importance_fig, prediction, pred_details, summary, aggregation_display = lightgbm_core.run_lightgbm_and_visualize(
+ df, target_col, new_point_dict, n_estimators, num_leaves, min_data_in_leaf, learning_rate, train_test_split_ratio, problem_type, use_early_stopping, early_stopping_rounds
 )
 
 feature_cols = [c for c in df.columns if c != target_col]
 first_tree_fig = lightgbm_core.get_individual_tree_visualization(
+ lightgbm_core._get_current_model(), 0, feature_cols, problem_type, num_leaves
 )
 
 updated_tree_selector = update_tree_selector_choices(n_estimators)
 
 
 def update_tree_selector_choices(n_estimators):
+ # Only show trees that were actually trained (respect early stopping)
+ try:
+ model = lightgbm_core._get_current_model()
+ actual_trees = 0
+ if model is not None:
+ # Prefer evals_result_ count if available
+ if hasattr(model, 'evals_result_') and model.evals_result_:
+ eval_results = model.evals_result_
+ if 'train' in eval_results and eval_results['train']:
+ metric_name = list(eval_results['train'].keys())[0]
+ actual_trees = len(eval_results['train'][metric_name])
+ print(f"Tree selector: eval history reports {actual_trees} trees trained")
+ # Fallback to best_iteration if present
+ if actual_trees == 0 and hasattr(model, 'best_iteration') and model.best_iteration is not None:
+ actual_trees = int(model.best_iteration) + 1
+ print(f"Tree selector: using best_iteration -> {actual_trees} trees")
+ # Final fallback to model.num_trees()
+ if actual_trees == 0 and hasattr(model, 'num_trees'):
+ actual_trees = int(model.num_trees())
+ print(f"Tree selector: using num_trees() -> {actual_trees} trees")
+
+ # Ensure at least one option to avoid empty dropdown
+ actual_trees = max(1, actual_trees)
+ # For UI performance, cap at 100
+ trees_to_show = min(actual_trees, 100)
+
+ # Debug
+ print(f"Tree selector: requested={n_estimators}, available={actual_trees}, showing={trees_to_show}")
+ except Exception as e:
+ trees_to_show = min(max(1, int(n_estimators)), 100)
+ print(f"Tree selector error: {e}, falling back to requested count {trees_to_show}")
+
+ choices = [f"Tree {i+1}" for i in range(trees_to_show)]
 return gr.Dropdown(choices=choices, value="Tree 1")
 
 
+ def update_tree_visualization(tree_selector, num_leaves=31):
+ global current_dataframe, current_target_column, current_problem_type
 
 if current_dataframe is None or current_dataframe.empty:
 return None
 
+ if current_target_column is None or current_problem_type is None:
+ return None
+
 try:
 model = lightgbm_core._get_current_model()
 if model is None:
 return None
 
 tree_index = int(tree_selector.split()[-1]) - 1
+
+ # Use the stored target column and problem type
+ feature_cols = [c for c in current_dataframe.columns if c != current_target_column]
+
+ # Use the num_leaves parameter from the UI
+ tree_fig = lightgbm_core.get_individual_tree_visualization(model, tree_index, feature_cols, current_problem_type, num_leaves)
 
 return tree_fig
 except Exception as e:
+ print(f"Tree visualization error: {str(e)}") # For debugging
 return None
 
 
 n_estimators = gr.Number(
 label="Number of Trees",
 value=100, minimum=1, maximum=1000, precision=0,
+ info="Requested number of trees (up to 1000). Actual trained trees may be fewer due to early stopping."
 )
 learning_rate = gr.Slider(
 label="Learning Rate",
 
 info="Step size shrinkage for each tree"
 )
 with gr.Row():
+ num_leaves = gr.Number(
+ label="Number of Leaves",
+ value=31, minimum=2, maximum=127, precision=0,
+ info="Maximum number of leaves in one tree (controls complexity, typically 31-70)"
+ )
+ min_data_in_leaf = gr.Number(
+ label="Min Data in Leaf",
+ value=20, minimum=1, maximum=1000, precision=0,
+ info="Minimum number of data points in one leaf (prevents overfitting)"
 )
 
 gr.Markdown("**📊 Data Split Configuration**")
 
 value=True,
 info="Display train/validation set information"
 )
+
+ with gr.Row():
+ use_early_stopping = gr.Checkbox(
+ label="Use Early Stopping",
+ value=True,
+ info="Stop training early if validation performance doesn't improve (prevents overfitting)"
+ )
+ early_stopping_rounds = gr.Number(
+ label="Early Stopping Rounds",
+ value=20, minimum=5, maximum=100, precision=0,
+ info="Number of rounds to wait before stopping (20% of trees by default)"
+ )
 
 inputs_group = gr.Group(visible=False)
 with inputs_group:
 
 feature_importance_plot = gr.Plot(label="Feature Importance", visible=True)
 aggregation_display = gr.HTML("**⚡ LightGBM Process**<br><br>LightGBM details will appear here showing how the prediction builds up.", label="⚡ LightGBM Process")
 
+ gr.Markdown("""⚡ **LightGBM Leaf-wise Tree Tips**:
 - **📉 Loss Evolution Chart**: Monitor training and validation loss to understand model convergence with early stopping.
 - **🌳 Individual Tree Visualization**: Select any tree to see its leaf-wise structure and contribution.
 - **📊 Feature Importance**: Displays which features are most influential using gradient-based importance.
 - **🎯 Parameter Tuning**: Try different **number of trees** (up to 1000) and **learning rate** (0.001-1.0).
 - **⚡ Learning Rate**: Default 0.1 works well; lower values (0.01-0.05) for more conservative models, higher values (0.2-0.3) for faster convergence.
+ - **🍃 Number of Leaves**: Controls tree complexity (default 31). For a depth-7 equivalent, use ~70-80 leaves instead of 127 to prevent overfitting.
+ - **📊 Min Data in Leaf**: Prevents overfitting by requiring a minimum number of samples per leaf (default 20). Increase for larger datasets.
+ - **🎯 Leaf-wise Growth**: LightGBM grows trees leaf-by-leaf for faster convergence compared to depth-wise growth.
 - **🔍 Tree Analysis**: Use the tree selector to understand how each tree contributes to gradient boosting ensemble.
+ - **⏹️ Early Stopping**: The tree selector lists requested trees, but only actually trained trees can be visualized. Check the console for actual vs. requested tree counts.
 """)
 
 vlai_template.create_footer()
 
 
 run_prediction_btn.click(
 fn=execute_prediction,
+ inputs=[data_preview, target_column, n_estimators, num_leaves, min_data_in_leaf, learning_rate, train_test_split_ratio, show_split_info, use_early_stopping, early_stopping_rounds] + input_components,
 outputs=[loss_chart, individual_tree_plot, feature_importance_plot, aggregation_display, tree_selector],
 )
 
 tree_selector.change(
 fn=update_tree_visualization,
+ inputs=[tree_selector, num_leaves],
 outputs=[individual_tree_plot],
 )
 
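For reference, the dropdown logic above condenses to the standalone sketch below; count_trained_trees is a hypothetical helper name, and booster stands for any trained lgb.Booster produced by the demo:

def count_trained_trees(booster, requested):
    """Fallback chain: eval history, then best_iteration, then num_trees()."""
    # 1) Prefer the per-iteration eval history the demo attaches as evals_result_.
    evals = getattr(booster, "evals_result_", None)
    if evals and evals.get("train"):
        metric = next(iter(evals["train"]))
        return len(evals["train"][metric])
    # 2) Fall back to best_iteration, which the code above treats as 0-based (hence +1).
    best = getattr(booster, "best_iteration", None)
    if best:
        return int(best) + 1
    # 3) Last resort: trees stored in the booster (iterations x classes for multiclass).
    return booster.num_trees() if hasattr(booster, "num_trees") else int(requested)

The dropdown then lists min(count, 100) entries, matching the 100-tree cap noted in the README.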
requirements.txt CHANGED
@@ -2,5 +2,8 @@ gradio>=5.38.0
 pandas>=1.5.0
 scikit-learn>=1.3.0
 numpy>=1.24.0
- plotly>=5.15.0
- lightgbm>=4.0.0
 
 pandas>=1.5.0
 scikit-learn>=1.3.0
 numpy>=1.24.0
+ lightgbm>=4.0.0
+ matplotlib>=3.5.0
+ graphviz>=0.20.0
+ Pillow>=8.0.0
+ plotly>=5.15.0
src/__pycache__/lightgbm_core.cpython-312.pyc CHANGED
Binary files a/src/__pycache__/lightgbm_core.cpython-312.pyc and b/src/__pycache__/lightgbm_core.cpython-312.pyc differ
 
src/lightgbm_core.py CHANGED
@@ -1,5 +1,7 @@
 import pandas as pd
 import numpy as np
 
 import lightgbm as lgb
 from sklearn.preprocessing import LabelEncoder
@@ -8,8 +10,20 @@ from sklearn.datasets import (
 )
 from sklearn.model_selection import train_test_split
 from sklearn.metrics import accuracy_score, mean_squared_error
 import plotly.graph_objects as go
 import plotly.express as px
 
 _current_model = None
 
@@ -151,7 +165,7 @@ def preprocess_data(df, target_col, new_point_dict):
 
 
 def run_lightgbm_and_visualize(df, target_col, new_point_dict,
- n_estimators, max_depth, learning_rate, train_test_split_ratio=0.8, problem_type=None):
 X, y, new_point, feature_cols, _ = preprocess_data(df, target_col, new_point_dict)
 
 if problem_type is None:
@@ -159,8 +173,10 @@ def run_lightgbm_and_visualize(df, target_col, new_point_dict,
 
 if n_estimators < 1:
 return None, None, None, None, "Number of estimators must be ≥ 1.", None
- if max_depth is not None and max_depth < 1:
- return None, None, None, None, "Max depth must be ≥ 1.", None
 if learning_rate <= 0 or learning_rate > 1:
 return None, None, None, None, "Learning rate must be between 0 and 1.", None
 
@@ -175,8 +191,8 @@ def run_lightgbm_and_visualize(df, target_col, new_point_dict,
 'objective': 'multiclass' if problem_type == "classification" and len(np.unique(y)) > 2 else 'binary' if problem_type == "classification" else 'regression',
 'num_class': len(np.unique(y)) if problem_type == "classification" and len(np.unique(y)) > 2 else None,
 'boosting_type': 'gbdt',
- 'num_leaves': 2**max_depth - 1 if max_depth else 31,
- 'max_depth': int(max_depth) if max_depth else -1,
 'learning_rate': float(learning_rate),
 'feature_fraction': 0.9,
 'bagging_fraction': 0.8,
@@ -193,17 +209,76 @@ def run_lightgbm_and_visualize(df, target_col, new_point_dict,
 train_data = lgb.Dataset(X_train, label=y_train)
 val_data = lgb.Dataset(X_val, label=y_val, reference=train_data)
 
- # Train model with early stopping
 model = lgb.train(
 params,
 train_data,
 valid_sets=[train_data, val_data],
 valid_names=['train', 'eval'],
 num_boost_round=n_estimators,
- callbacks=[lgb.early_stopping(stopping_rounds=50, verbose=False), lgb.log_evaluation(0)]
 )
 
- prediction = model.predict(new_point, num_iteration=model.best_iteration)[0]
 if problem_type == "classification":
 if len(np.unique(y)) == 2: # Binary classification
 prediction = int(prediction > 0.5)
@@ -246,15 +321,48 @@ def run_lightgbm_and_visualize(df, target_col, new_point_dict,
 loss_chart_fig = create_loss_chart(model, X_train, y_train, X_val, y_val, problem_type)
 importance_fig = create_feature_importance_plot(model, feature_cols)
 prediction_details = create_prediction_details(model, new_point[0], feature_cols, target_col, prediction, problem_type)
- summary = create_algorithm_summary(model, problem_type, n_estimators, max_depth, learning_rate, feature_cols)
 aggregation_display = create_lightgbm_aggregation_display(model, new_point[0], problem_type, target_col, df, split_info)
 
 return None, loss_chart_fig, importance_fig, prediction, prediction_details, summary, aggregation_display
 
 
 def create_loss_chart(model, X_train, y_train, X_val, y_val, problem_type):
- """Create a loss chart showing training and validation loss evolution during LightGBM training"""
 try:
 # Get evaluation results from LightGBM training history
 eval_results = model.evals_result_
 
@@ -274,8 +382,9 @@ def create_loss_chart(model, X_train, y_train, X_val, y_val, problem_type):
 y=train_losses,
 mode='lines+markers',
 name='Training Loss',
- line=dict(color='#8E44AD', width=2),
- marker=dict(size=4)
 ))
 
 # Plot validation loss
@@ -284,8 +393,9 @@ def create_loss_chart(model, X_train, y_train, X_val, y_val, problem_type):
 y=val_losses,
 mode='lines+markers',
 name='Validation Loss',
- line=dict(color='#3498DB', width=2),
- marker=dict(size=4)
 ))
 
 # Add early stopping line if available
@@ -294,7 +404,9 @@ def create_loss_chart(model, X_train, y_train, X_val, y_val, problem_type):
 x=model.best_iteration + 1,
 line_dash="dash",
 line_color="red",
- annotation_text="Early Stop"
 )
 
 fig.update_layout(
@@ -302,7 +414,8 @@ def create_loss_chart(model, X_train, y_train, X_val, y_val, problem_type):
 xaxis_title="Boosting Round",
 yaxis_title=metric_name.replace('_', ' ').title(),
 plot_bgcolor="white",
- height=400,
 legend=dict(
 yanchor="top",
 y=0.99,
@@ -331,7 +444,7 @@ def create_loss_chart(model, X_train, y_train, X_val, y_val, problem_type):
 )
 fig.update_layout(
 title="LightGBM Training Progress - Loss Evolution",
- height=400,
 plot_bgcolor="white"
 )
 return fig
@@ -339,204 +452,333 @@ def create_loss_chart(model, X_train, y_train, X_val, y_val, problem_type):
 
 
- def create_individual_tree_visualization(model, tree_index, feature_cols, problem_type):
- """Create visualization of individual LightGBM tree"""
 try:
- # LightGBM doesn't expose individual trees easily, so create a representative visualization
- if tree_index < model.num_trees():
- return create_lightgbm_tree_plot(tree_index, feature_cols, problem_type, model)
 else:
- raise IndexError(f"Tree index {tree_index} out of range")
 
 except Exception as e:
- # Fallback visualization
- fig = go.Figure()
- fig.add_annotation(
- text=f"LightGBM Tree {tree_index + 1} Visualization<br>Unable to extract tree structure<br>Error: {str(e)}",
- xref="paper", yref="paper",
- x=0.5, y=0.5, xanchor='center', yanchor='middle',
- showarrow=False,
- font=dict(size=14)
- )
- fig.update_layout(
- title=f"LightGBM Tree {tree_index + 1} Structure",
- height=500,
- plot_bgcolor="white"
- )
 return fig
 
 
- def create_lightgbm_tree_plot(tree_index, feature_cols, problem_type, model):
 """Create tree visualization for LightGBM trees"""
 try:
 # Create a representative visualization for LightGBM tree
- return create_manual_tree_plot(tree_index, feature_cols, problem_type, "LightGBM", 1.0, model)
 
 except Exception as e:
 # Fallback to manual tree creation
- return create_manual_tree_plot(tree_index, feature_cols, problem_type, "LightGBM", 1.0)
 
 
- def create_manual_tree_plot(tree_index, feature_cols, problem_type, model_type, weight=1.0, model=None):
 """Create a manual tree visualization when tree structure is not easily accessible"""
- fig = go.Figure()
 
- # Create a sample tree structure for demonstration
 import random
 random.seed(tree_index) # Consistent trees for same index
 
- # For LightGBM, we can try to get some parameters
- if model_type == "LightGBM" and model:
 try:
- # Try to get max_depth from model params
- actual_depth = model.params.get('max_depth', 6) if hasattr(model, 'params') else 6
- if actual_depth == -1: # LightGBM default unlimited depth
- actual_depth = 6 # Set reasonable default for visualization
 except:
- actual_depth = 6 # LightGBM typical depth
 else:
- actual_depth = 1 # fallback for other models
 
- # Root node
 root_feature = random.choice(feature_cols) if feature_cols else "feature_0"
 root_threshold = round(random.uniform(0.1, 5.0), 2)
 
- # Create tree structure based on actual depth
- if actual_depth <= 2 or model_type != "LightGBM":
- # Simple tree (depth 1-2)
- positions = {
- 'root': (0, 1),
- 'left': (-1, 0),
- 'right': (1, 0)
- }
-
- if model_type == "LightGBM":
- labels = {
- 'root': f"{root_feature}<br>≤ {root_threshold}<br>Tree: {tree_index + 1}<br>Gradient Boosting",
- 'left': f"Leaf (≤)<br>Output: {round(random.uniform(-1, 1), 3)}<br>Samples: {random.randint(20, 80)}",
- 'right': f"Leaf (>)<br>Output: {round(random.uniform(-1, 1), 3)}<br>Samples: {random.randint(20, 80)}"
- }
- else:
- labels = {
- 'root': f"{root_feature}<br>≤ {root_threshold}<br>Weight: {weight:.3f}<br>Decision Stump",
- 'left': f"Leaf (≤)<br>Value: {round(random.uniform(-1, 1), 3)}<br>Samples: {random.randint(20, 80)}",
- 'right': f"Leaf (>)<br>Value: {round(random.uniform(-1, 1), 3)}<br>Samples: {random.randint(20, 80)}"
- }
-
- colors = {
- 'root': '#8E44AD' if model_type == "LightGBM" else '#81C784', # Purple for LightGBM, Green for others
- 'left': '#3498DB' if model_type == "LightGBM" else '#FFB74D', # Blue for LightGBM, Orange for others
- 'right': '#3498DB' if model_type == "LightGBM" else '#FFB74D' # Blue for LightGBM, Orange for others
- }
-
- edges = [('root', 'left'), ('root', 'right')]
- title_suffix = "Gradient Boosting Tree" if model_type == "LightGBM" else "Decision Stump"
-
- else:
- # Deeper tree (depth 2+)
- positions = {
- 'root': (0, 2),
- 'left': (-1.5, 1),
- 'right': (1.5, 1),
- 'left_left': (-2.5, 0),
- 'left_right': (-0.5, 0),
- 'right_left': (0.5, 0),
- 'right_right': (2.5, 0)
- }
-
- if model_type == "LightGBM":
- labels = {
- 'root': f"{root_feature}<br>≤ {root_threshold}<br>Tree: {tree_index + 1}<br>Depth: {actual_depth}",
- 'left': f"{random.choice(feature_cols) if feature_cols else 'feature_1'}<br>≤ {round(random.uniform(0.1, 3.0), 2)}<br>Samples: 75",
- 'right': f"{random.choice(feature_cols) if feature_cols else 'feature_2'}<br>≤ {round(random.uniform(0.1, 3.0), 2)}<br>Samples: 75",
- 'left_left': f"Leaf<br>Output: {round(random.uniform(-1, 1), 3)}<br>Samples: 25",
- 'left_right': f"Leaf<br>Output: {round(random.uniform(-1, 1), 3)}<br>Samples: 50",
- 'right_left': f"Leaf<br>Output: {round(random.uniform(-1, 1), 3)}<br>Samples: 30",
- 'right_right': f"Leaf<br>Output: {round(random.uniform(-1, 1), 3)}<br>Samples: 45"
- }
-
- colors = {
- 'root': '#8E44AD', 'left': '#8E44AD', 'right': '#8E44AD', # Purple for split nodes
- 'left_left': '#3498DB', 'left_right': '#3498DB', 'right_left': '#3498DB', 'right_right': '#3498DB' # Blue for leaves
- }
- else:
- labels = {
- 'root': f"{root_feature}<br>≤ {root_threshold}<br>Weight: {weight:.3f}<br>Depth: {actual_depth}",
- 'left': f"{random.choice(feature_cols) if feature_cols else 'feature_1'}<br>≤ {round(random.uniform(0.1, 3.0), 2)}<br>Samples: 75",
- 'right': f"{random.choice(feature_cols) if feature_cols else 'feature_2'}<br>≤ {round(random.uniform(0.1, 3.0), 2)}<br>Samples: 75",
- 'left_left': f"Leaf<br>Value: {round(random.uniform(-1, 1), 3)}<br>Samples: 25",
- 'left_right': f"Leaf<br>Value: {round(random.uniform(-1, 1), 3)}<br>Samples: 50",
- 'right_left': f"Leaf<br>Value: {round(random.uniform(-1, 1), 3)}<br>Samples: 30",
- 'right_right': f"Leaf<br>Value: {round(random.uniform(-1, 1), 3)}<br>Samples: 45"
- }
-
- colors = {
- 'root': '#81C784', 'left': '#81C784', 'right': '#81C784', # Green for split nodes
- 'left_left': '#FFB74D', 'left_right': '#FFB74D', 'right_left': '#FFB74D', 'right_right': '#FFB74D' # Orange for leaves
- }
-
- edges = [
- ('root', 'left'), ('root', 'right'),
- ('left', 'left_left'), ('left', 'left_right'),
- ('right', 'right_left'), ('right', 'right_right')
- ]
- title_suffix = f"Depth {actual_depth} Gradient Boosting Tree" if model_type == "LightGBM" else f"Depth {actual_depth} Tree"
 
- edge_x, edge_y = [], []
- for parent, child in edges:
- parent_pos = positions[parent]
- child_pos = positions[child]
- edge_x.extend([parent_pos[0], child_pos[0], None])
- edge_y.extend([parent_pos[1], child_pos[1], None])
 
- fig.add_trace(go.Scatter(
- x=edge_x, y=edge_y,
- mode='lines',
- line=dict(color='gray', width=2),
- showlegend=False,
- hoverinfo='none'
- ))
 
- # Draw nodes
- for node_id, (x, y) in positions.items():
- fig.add_trace(go.Scatter(
- x=[x], y=[y],
- mode='markers+text',
- marker=dict(
- size=35,
- color=colors[node_id],
- line=dict(width=2, color='darkblue'),
- symbol='circle'
- ),
- text=labels[node_id],
- textposition='middle center',
- textfont=dict(size=9, color='black'),
- showlegend=False,
- hoverinfo='text',
- hovertext=labels[node_id]
- ))
 
- # Adjust layout based on tree depth
- if actual_depth == 1:
- x_range, y_range, height = [-1.5, 1.5], [-0.5, 1.5], 400
- else:
- x_range, y_range, height = [-3, 3], [-0.5, 2.5], 600
 
- fig.update_layout(
- title=f"{model_type} Estimator {tree_index + 1} Structure - {title_suffix} ({problem_type.title()})",
- xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=x_range),
- yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=y_range),
- plot_bgcolor="white",
- height=height,
- margin=dict(l=40, r=40, t=60, b=40),
- showlegend=False
- )
 
 return fig
 
 
- def get_individual_tree_visualization(model, tree_index, feature_cols, problem_type):
- return create_individual_tree_visualization(model, tree_index, feature_cols, problem_type)
 
 
 def create_feature_importance_plot(model, feature_cols):
@@ -545,27 +787,51 @@ def create_feature_importance_plot(model, feature_cols):
 importances = model.feature_importance(importance_type='gain')
 order = np.argsort(importances)[::-1]
 
 fig = go.Figure()
- fig.add_trace(
- go.Bar(
- x=[feature_cols[i] for i in order],
- y=importances[order],
- text=[f"{importances[i]:.0f}" for i in order],
 textposition="auto",
- marker_color="#8E44AD", # LightGBM purple theme
- hovertemplate="<b>%{x}</b><br>Importance: %{y:.0f}<extra></extra>",
- )
- )
 fig.update_layout(
 title="LightGBM Feature Importance (Gain)",
 xaxis_title="Features",
 yaxis_title="Importance Score",
 plot_bgcolor="white",
- height=400,
 margin=dict(l=40, r=40, t=60, b=40),
 )
- fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="lightgray")
- fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor="lightgray")
 return fig
 except:
 fig = go.Figure()
@@ -578,7 +844,7 @@ def create_feature_importance_plot(model, feature_cols):
 )
 fig.update_layout(
 title="LightGBM Feature Importance",
- height=400,
 plot_bgcolor="white"
 )
 return fig
@@ -607,15 +873,16 @@ def create_prediction_details(model, new_point, feature_cols, target_col, predic
 return f"Predicted Value: {prediction:.3f}"
 
 
- def create_algorithm_summary(model, problem_type, n_estimators, max_depth, learning_rate, feature_cols):
 num_trees = model.num_trees() if hasattr(model, 'num_trees') else n_estimators
 return f"""
 **LightGBM {problem_type.title()} Model Summary:**
 - Trees Built: {num_trees}
- - Max Depth: {max_depth if max_depth != -1 else 'Unlimited'}
 - Learning Rate: {learning_rate}
 - Features: {len(feature_cols)}
- - Algorithm: Gradient Boosting (LightGBM)
 """
 
 import pandas as pd
 import numpy as np
+ import io
+ import base64
 
 import lightgbm as lgb
 from sklearn.preprocessing import LabelEncoder
 
 )
 from sklearn.model_selection import train_test_split
 from sklearn.metrics import accuracy_score, mean_squared_error
+ # Import Plotly for interactive charts
 import plotly.graph_objects as go
 import plotly.express as px
+ import matplotlib.pyplot as plt
+ import matplotlib
+ matplotlib.use('Agg') # Use non-interactive backend
+
+ # Add graphviz import for tree visualization
+ try:
+ import graphviz
+ GRAPHVIZ_AVAILABLE = True
+ except ImportError:
+ GRAPHVIZ_AVAILABLE = False
+ print("Warning: graphviz not available. Tree visualization will use fallback methods.")
 
 _current_model = None
 
 
 
 def run_lightgbm_and_visualize(df, target_col, new_point_dict,
+ n_estimators, num_leaves, min_data_in_leaf, learning_rate, train_test_split_ratio=0.8, problem_type=None, use_early_stopping=True, early_stopping_rounds=20):
 X, y, new_point, feature_cols, _ = preprocess_data(df, target_col, new_point_dict)
 
 if problem_type is None:
 
 if n_estimators < 1:
 return None, None, None, None, "Number of estimators must be ≥ 1.", None
+ if num_leaves < 2:
+ return None, None, None, None, "Number of leaves must be ≥ 2.", None
+ if min_data_in_leaf < 1:
+ return None, None, None, None, "Min data in leaf must be ≥ 1.", None
 if learning_rate <= 0 or learning_rate > 1:
 return None, None, None, None, "Learning rate must be between 0 and 1.", None
 
 'objective': 'multiclass' if problem_type == "classification" and len(np.unique(y)) > 2 else 'binary' if problem_type == "classification" else 'regression',
 'num_class': len(np.unique(y)) if problem_type == "classification" and len(np.unique(y)) > 2 else None,
 'boosting_type': 'gbdt',
+ 'num_leaves': int(num_leaves), # Main parameter to control tree complexity
+ 'min_data_in_leaf': int(min_data_in_leaf), # Important parameter to prevent overfitting
 'learning_rate': float(learning_rate),
 'feature_fraction': 0.9,
 'bagging_fraction': 0.8,
 
 train_data = lgb.Dataset(X_train, label=y_train)
 val_data = lgb.Dataset(X_val, label=y_val, reference=train_data)
 
+ # Custom callback to capture evaluation results
+ evals_result = {}
+
+ def record_eval(env):
+ """Custom callback to record evaluation results"""
+ if 'train' not in evals_result:
+ evals_result['train'] = {}
+ evals_result['eval'] = {}
+
+ # Get the metric name from the first evaluation
+ if env.evaluation_result_list:
+ metric_name = env.evaluation_result_list[0][1] # Get metric name from first result
+
+ if metric_name not in evals_result['train']:
+ evals_result['train'][metric_name] = []
+ evals_result['eval'][metric_name] = []
+
+ # Record both training and validation results
+ for eval_name, eval_metric, eval_result, _ in env.evaluation_result_list:
+ if eval_name == 'train':
+ evals_result['train'][eval_metric].append(eval_result)
+ elif eval_name == 'eval':
+ evals_result['eval'][eval_metric].append(eval_result)
+
+ # Train model with configurable early stopping
+ callbacks = [lgb.log_evaluation(0), record_eval]
+ if use_early_stopping:
+ # Use user-specified early stopping rounds, but ensure it's reasonable
+ stopping_rounds = min(early_stopping_rounds, max(10, int(n_estimators * 0.2)))
+ callbacks.append(lgb.early_stopping(stopping_rounds=stopping_rounds, verbose=False))
+ print(f"Training with early stopping: {stopping_rounds} rounds")
+ else:
+ print(f"Training without early stopping: {n_estimators} rounds")
+
+ # Train the model with evaluation sets
 model = lgb.train(
 params,
 train_data,
 valid_sets=[train_data, val_data],
 valid_names=['train', 'eval'],
 num_boost_round=n_estimators,
+ callbacks=callbacks
 )
 
+ # Store evaluation results in the model
+ model.evals_result_ = evals_result
+
+ # Debug information
+ print(f"Training completed. Model has evals_result_: {hasattr(model, 'evals_result_')}")
+ print(f"Custom evals_result captured: {bool(evals_result)}")
+ if evals_result:
+ print(f"Custom evaluation results keys: {list(evals_result.keys())}")
+ if 'train' in evals_result:
+ print(f"Train metrics: {list(evals_result['train'].keys())}")
+ if evals_result['train']:
+ metric_name = list(evals_result['train'].keys())[0]
+ print(f"Train {metric_name} values count: {len(evals_result['train'][metric_name])}")
+ if 'eval' in evals_result:
+ print(f"Eval metrics: {list(evals_result['eval'].keys())}")
+ if evals_result['eval']:
+ metric_name = list(evals_result['eval'].keys())[0]
+ print(f"Eval {metric_name} values count: {len(evals_result['eval'][metric_name])}")
+ else:
+ print("No evaluation results captured by custom callback")
+
+ # Use best iteration if early stopping was used, otherwise use all trees
+ if use_early_stopping and hasattr(model, 'best_iteration'):
+ prediction = model.predict(new_point, num_iteration=model.best_iteration)[0]
+ else:
+ prediction = model.predict(new_point)[0]
 if problem_type == "classification":
 if len(np.unique(y)) == 2: # Binary classification
 prediction = int(prediction > 0.5)
 
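An aside on the hand-rolled record_eval callback above: LightGBM also ships a built-in lgb.record_evaluation callback that fills a history dict with the same {valid_name: {metric: [values]}} shape, which would make the custom callback unnecessary. A minimal runnable sketch on synthetic data, assuming the lightgbm>=4.0.0 pin from requirements.txt:

import numpy as np
import lightgbm as lgb

# Tiny synthetic regression problem, only to keep the sketch self-contained.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = X @ rng.normal(size=5) + rng.normal(scale=0.1, size=200)
train_data = lgb.Dataset(X[:160], label=y[:160])
val_data = lgb.Dataset(X[160:], label=y[160:], reference=train_data)

evals_result = {}  # filled in-place, keyed by the names given in valid_names
model = lgb.train(
    {"objective": "regression", "learning_rate": 0.1, "verbose": -1},
    train_data,
    valid_sets=[train_data, val_data],
    valid_names=["train", "eval"],
    num_boost_round=100,
    callbacks=[
        lgb.record_evaluation(evals_result),  # built-in counterpart of record_eval
        lgb.log_evaluation(0),
        lgb.early_stopping(stopping_rounds=20, verbose=False),
    ],
)
print(len(evals_result["eval"]["l2"]), model.best_iteration)  # rounds recorded, best round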
321
  loss_chart_fig = create_loss_chart(model, X_train, y_train, X_val, y_val, problem_type)
322
  importance_fig = create_feature_importance_plot(model, feature_cols)
323
  prediction_details = create_prediction_details(model, new_point[0], feature_cols, target_col, prediction, problem_type)
324
+ summary = create_algorithm_summary(model, problem_type, n_estimators, num_leaves, min_data_in_leaf, learning_rate, feature_cols)
325
  aggregation_display = create_lightgbm_aggregation_display(model, new_point[0], problem_type, target_col, df, split_info)
326
 
327
  return None, loss_chart_fig, importance_fig, prediction, prediction_details, summary, aggregation_display
328
 
329
 
330
  def create_loss_chart(model, X_train, y_train, X_val, y_val, problem_type):
331
+ """Create an interactive loss chart showing training and validation loss evolution during LightGBM training"""
332
  try:
333
+ # Debug information
334
+ print(f"Loss chart: Model has evals_result_ attribute: {hasattr(model, 'evals_result_')}")
335
+ if hasattr(model, 'evals_result_'):
336
+ print(f"Loss chart: evals_result_ content: {model.evals_result_}")
337
+ if model.evals_result_:
338
+ print(f"Loss chart: evals_result_ keys: {list(model.evals_result_.keys())}")
339
+ if 'train' in model.evals_result_:
340
+ print(f"Loss chart: train keys: {list(model.evals_result_['train'].keys())}")
341
+ if 'eval' in model.evals_result_:
342
+ print(f"Loss chart: eval keys: {list(model.evals_result_['eval'].keys())}")
343
+ else:
344
+ print("Loss chart: evals_result_ is empty")
345
+ else:
346
+ print("Loss chart: Model does not have evals_result_ attribute")
347
+
348
+ # Check if model has evaluation results
349
+ if not hasattr(model, 'evals_result_') or not model.evals_result_:
350
+ # If no evaluation results, show a message instead of simulated data
351
+ fig = go.Figure()
352
+ fig.add_annotation(
353
+ text="No training history available<br>Run training with validation data to see loss evolution",
354
+ xref="paper", yref="paper",
355
+ x=0.5, y=0.5, xanchor='center', yanchor='middle',
356
+ showarrow=False,
357
+ font=dict(size=14)
358
+ )
359
+ fig.update_layout(
360
+ title="LightGBM Training Progress - Loss Evolution",
361
+ height=500,
362
+ plot_bgcolor="white"
363
+ )
364
+ return fig
365
+
366
  # Get evaluation results from LightGBM training history
367
  eval_results = model.evals_result_
368
 
 
382
  y=train_losses,
383
  mode='lines+markers',
384
  name='Training Loss',
385
+ line=dict(color='#8E44AD', width=3),
386
+ marker=dict(size=6, color='#8E44AD'),
387
+ hovertemplate='<b>Training Loss</b><br>Round: %{x}<br>Loss: %{y:.4f}<extra></extra>'
388
  ))
389
 
390
  # Plot validation loss
 
393
  y=val_losses,
394
  mode='lines+markers',
395
  name='Validation Loss',
396
+ line=dict(color='#3498DB', width=3),
397
+ marker=dict(size=6, color='#3498DB'),
398
+ hovertemplate='<b>Validation Loss</b><br>Round: %{x}<br>Loss: %{y:.4f}<extra></extra>'
399
  ))
400
 
401
  # Add early stopping line if available
 
404
  x=model.best_iteration + 1,
405
  line_dash="dash",
406
  line_color="red",
407
+ line_width=2,
408
+ annotation_text=f"Best Iteration ({model.best_iteration + 1})",
409
+ annotation_position="top"
410
  )
411
 
412
  fig.update_layout(
 
414
  xaxis_title="Boosting Round",
415
  yaxis_title=metric_name.replace('_', ' ').title(),
416
  plot_bgcolor="white",
417
+ height=500,
418
+ hovermode='x unified',
419
  legend=dict(
420
  yanchor="top",
421
  y=0.99,
 
444
  )
445
  fig.update_layout(
446
  title="LightGBM Training Progress - Loss Evolution",
447
+ height=500,
448
  plot_bgcolor="white"
449
  )
450
  return fig
 
452
 
453
 
454
 
455
+ def create_individual_tree_visualization(model, tree_index, feature_cols, problem_type, num_leaves=None):
456
+ """Create visualization of individual LightGBM tree using multiple methods with fallback"""
457
  try:
458
+ # Check if model is valid
459
+ if model is None:
460
+ raise Exception("Model is None - please run prediction first")
461
+
462
+ # Check if model has the required attributes
463
+ if not hasattr(model, 'num_trees'):
464
+ raise Exception("Model does not have num_trees attribute")
465
+
466
+ # Check if tree index is valid - use actual trees trained, not just best iteration
467
+ actual_trees = model.num_trees()
468
+ if hasattr(model, 'evals_result_') and model.evals_result_:
469
+ eval_results = model.evals_result_
470
+ if 'train' in eval_results and eval_results['train']:
471
+ metric_name = list(eval_results['train'].keys())[0]
472
+ actual_trees = len(eval_results['train'][metric_name])
473
+
474
+ if tree_index >= actual_trees:
475
+ # If tree index is beyond what was actually trained, show a message
476
+ raise IndexError(f"Tree {tree_index + 1} was not trained. Only {actual_trees} trees were actually trained. Best iteration was {model.best_iteration + 1 if hasattr(model, 'best_iteration') else 'unknown'}.")
477
+
478
+ # Try multiple visualization methods in order of preference
479
+ try:
480
+ # Method 1: Try lightgbm.plot_tree first (as requested by user)
481
+ return create_lightgbm_native_tree_plot(model, tree_index, feature_cols, problem_type, num_leaves)
482
+ except Exception as plot_error:
483
+ print(f"Native plot failed: {plot_error}") # Debug info
484
+ try:
485
+ # Method 2: Try lightgbm.create_tree_digraph as fallback (best quality)
486
+ return create_lightgbm_digraph_tree_plot(model, tree_index, feature_cols, problem_type, num_leaves)
487
+ except Exception as digraph_error:
488
+ print(f"Digraph plot failed: {digraph_error}") # Debug info
489
+ try:
490
+ # Method 3: Fallback to manual visualization
491
+ return create_lightgbm_tree_plot(tree_index, feature_cols, problem_type, model, num_leaves)
492
+ except Exception as manual_error:
493
+ print(f"Manual plot failed: {manual_error}") # Debug info
494
+ raise Exception(f"All tree visualization methods failed: {manual_error}")
495
+
496
+ except Exception as e:
497
+ # Final fallback visualization with better error message
498
+ fig, ax = plt.subplots(figsize=(12, 8), dpi=100)
499
+ error_msg = str(e)
500
+ if "out of range" in error_msg:
501
+ # Get actual trees trained for better error message
502
+ actual_trees = model.num_trees() if model and hasattr(model, 'num_trees') else 0
503
+ if model and hasattr(model, 'evals_result_') and model.evals_result_:
504
+ eval_results = model.evals_result_
505
+ if 'train' in eval_results and eval_results['train']:
506
+ metric_name = list(eval_results['train'].keys())[0]
507
+ actual_trees = len(eval_results['train'][metric_name])
508
+
509
+ best_iteration = model.best_iteration + 1 if model and hasattr(model, 'best_iteration') else 'unknown'
510
+ display_msg = f"Tree {tree_index + 1} was not trained.\nOnly {actual_trees} trees were actually trained.\nBest iteration was {best_iteration}.\nPlease select a tree from 1 to {actual_trees}."
511
  else:
512
+ display_msg = f"Unable to visualize Tree {tree_index + 1}\nError: {error_msg}"
513
+
514
+ ax.text(0.5, 0.5, display_msg, ha='center', va='center', fontsize=14, color='red', transform=ax.transAxes)
515
+ ax.set_title(f"LightGBM Tree {tree_index + 1} Structure", fontsize=16, fontweight='bold')
516
+ ax.set_xlim(0, 1)
517
+ ax.set_ylim(0, 1)
518
+ ax.axis('off')
519
+
520
+ plt.tight_layout()
521
+ return fig
522
+
523
+
524
+ def create_lightgbm_digraph_tree_plot(model, tree_index, feature_cols, problem_type, num_leaves=None):
525
+ """Create tree visualization using lightgbm.create_tree_digraph for better tree structure"""
526
+ try:
527
+ # Check if model has the required number of trees - use actual trees trained
528
+ if not hasattr(model, 'num_trees'):
529
+ raise Exception("Model does not have num_trees attribute")
530
+
531
+ actual_trees = model.num_trees()
532
+ if hasattr(model, 'evals_result_') and model.evals_result_:
533
+ eval_results = model.evals_result_
534
+ if 'train' in eval_results and eval_results['train']:
535
+ metric_name = list(eval_results['train'].keys())[0]
536
+ actual_trees = len(eval_results['train'][metric_name])
537
+
538
+ if tree_index >= actual_trees:
539
+ raise Exception(f"Tree {tree_index + 1} was not trained. Only {actual_trees} trees were actually trained. Best iteration was {model.best_iteration + 1 if hasattr(model, 'best_iteration') else 'unknown'}.")
540
+
541
+ # Check if graphviz is available
542
+ if not GRAPHVIZ_AVAILABLE:
543
+ raise Exception("graphviz not available for tree visualization")
544
+
545
+ # Create tree digraph using LightGBM's native function
546
+ try:
547
+ # Use lightgbm.create_tree_digraph to create the tree structure
548
+ dot_data = lgb.create_tree_digraph(
549
+ model,
550
+ tree_index=tree_index,
551
+ show_info=['split_gain', 'internal_value', 'internal_count', 'leaf_count'],
552
+ precision=3
553
+ )
554
+ except Exception as digraph_error:
555
+ # Try with simpler parameters
556
+ try:
557
+ dot_data = lgb.create_tree_digraph(
558
+ model,
559
+ tree_index=tree_index,
560
+ show_info=['split_gain', 'internal_count'],
561
+ precision=2
562
+ )
563
+ except Exception as simple_error:
564
+ # Try with minimal parameters
565
+ dot_data = lgb.create_tree_digraph(
566
+ model,
567
+ tree_index=tree_index
568
+ )
569
+
570
+ # Convert dot data to matplotlib figure
571
+ try:
572
+ # Render the graph to PNG format
573
+ png_data = dot_data.pipe(format='png')
574
+
575
+ # Create a matplotlib figure and display the image
576
+ fig, ax = plt.subplots(figsize=(20, 12), dpi=150)
577
+
578
+ # Load the PNG data and display it
579
+ from PIL import Image
580
+ import io as io_module
581
+
582
+ image = Image.open(io_module.BytesIO(png_data))
583
+ ax.imshow(image)
584
+ ax.axis('off') # Hide axes
585
+
586
+ # Add title and information
587
+ ax.set_title(f'LightGBM Tree {tree_index + 1} - {problem_type.title()} (Using lightgbm.create_tree_digraph)',
588
+ fontsize=18, fontweight='bold', pad=20, color='#8E44AD')
589
+
590
+ # Add num_leaves information if available
591
+ if num_leaves:
592
+ ax.text(0.02, 0.98, f'Max Leaves: {num_leaves}',
593
+ transform=ax.transAxes, fontsize=12,
594
+ bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.7),
595
+ verticalalignment='top')
596
+
597
+ # Add tree information
598
+ ax.text(0.98, 0.98, f'Tree Index: {tree_index + 1}\nTotal Trees: {model.num_trees()}',
599
+ transform=ax.transAxes, fontsize=10,
600
+ bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgreen", alpha=0.7),
601
+ verticalalignment='top', horizontalalignment='right')
602
+
603
+ plt.tight_layout()
604
+
605
+ return fig
606
+
607
+ except Exception as render_error:
608
+ raise Exception(f"Failed to render tree digraph: {str(render_error)}")
609
 
610
  except Exception as e:
611
+ # If lightgbm.create_tree_digraph fails, raise the error to trigger fallback
612
+ raise Exception(f"lightgbm.create_tree_digraph failed: {str(e)}")
613
+
614
+
615
+ def create_lightgbm_native_tree_plot(model, tree_index, feature_cols, problem_type, num_leaves=None):
616
+ """Create tree visualization using lightgbm.plot_tree native functionality"""
617
+ try:
618
+ # Check if model has the required number of trees - use actual trees trained
619
+ if not hasattr(model, 'num_trees'):
620
+ raise Exception("Model does not have num_trees attribute")
621
+
622
+ actual_trees = model.num_trees()
623
+ if hasattr(model, 'evals_result_') and model.evals_result_:
624
+ eval_results = model.evals_result_
625
+ if 'train' in eval_results and eval_results['train']:
626
+ metric_name = list(eval_results['train'].keys())[0]
627
+ actual_trees = len(eval_results['train'][metric_name])
628
+
629
+ if tree_index >= actual_trees:
630
+ raise Exception(f"Tree {tree_index + 1} was not trained. Only {actual_trees} trees were actually trained. Best iteration was {model.best_iteration + 1 if hasattr(model, 'best_iteration') else 'unknown'}.")
631
+
632
+ # Create a matplotlib figure with higher DPI for better quality
633
+ fig, ax = plt.subplots(figsize=(20, 12), dpi=150)
634
+
635
+ # Use lightgbm.plot_tree to create the tree visualization
636
+ # Try with different parameter combinations for better compatibility
637
+ try:
638
+ # First try with comprehensive information
639
+ lgb.plot_tree(
640
+ model,
641
+ tree_index=tree_index,
642
+ ax=ax,
643
+ show_info=['split_gain', 'internal_value', 'internal_count', 'leaf_count'],
644
+ precision=3,
645
+ figsize=(20, 12)
646
+ )
647
+ except Exception as plot_error:
648
+ print(f"Comprehensive plot failed: {plot_error}")
649
+ # Try with simpler parameters
650
+ try:
651
+ lgb.plot_tree(
652
+ model,
653
+ tree_index=tree_index,
654
+ ax=ax,
655
+ show_info=['split_gain', 'internal_count'],
656
+ precision=2,
657
+ figsize=(20, 12)
658
+ )
659
+ except Exception as simple_error:
660
+ print(f"Simple plot failed: {simple_error}")
661
+ # Try with minimal parameters
662
+ try:
663
+ lgb.plot_tree(
664
+ model,
665
+ tree_index=tree_index,
666
+ ax=ax,
667
+ figsize=(20, 12)
668
+ )
669
+ except Exception as minimal_error:
670
+ print(f"Minimal plot failed: {minimal_error}")
671
+ # Try without figsize parameter
672
+ lgb.plot_tree(
673
+ model,
674
+ tree_index=tree_index,
675
+ ax=ax
676
+ )
677
+
678
+ # Customize the plot
679
+ ax.set_title(f'LightGBM Tree {tree_index + 1} - {problem_type.title()} (Using lightgbm.plot_tree)',
680
+ fontsize=18, fontweight='bold', pad=20, color='#8E44AD')
681
+
682
+ # Add num_leaves information if available
683
+ if num_leaves:
684
+ ax.text(0.02, 0.98, f'Max Leaves: {num_leaves}',
685
+ transform=ax.transAxes, fontsize=12,
686
+ bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.7),
687
+ verticalalignment='top')
688
+
689
+ # Add tree information
690
+ ax.text(0.98, 0.98, f'Tree Index: {tree_index + 1}\nTotal Trees: {model.num_trees()}',
691
+ transform=ax.transAxes, fontsize=10,
692
+ bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgreen", alpha=0.7),
693
+ verticalalignment='top', horizontalalignment='right')
694
+
695
+ # Adjust layout
696
+ plt.tight_layout()
697
+
698
+ # Return the matplotlib figure directly (no Plotly)
699
  return fig
700
+
701
+ except Exception as e:
702
+ # Log the error for debugging
703
+ print(f"Native plot failed: {str(e)}")
704
+ # If lightgbm.plot_tree fails, raise the error to trigger fallback
705
+ raise Exception(f"lightgbm.plot_tree failed: {str(e)}")
706
 
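The guard at the top of `create_lightgbm_native_tree_plot` matters because early stopping can leave fewer trees than `n_estimators` requested. The check condenses to a few lines; a hedged sketch, assuming (as the code above does) a model exposing `num_trees()` and optionally a sklearn-style `evals_result_`:

```python
# Sketch of the early-stopping guard above: how many trees were actually trained?
# Assumes `model` exposes num_trees() and optionally sklearn-style evals_result_.
def actual_tree_count(model):
    n = model.num_trees()
    evals = getattr(model, 'evals_result_', None)
    if evals and 'train' in evals and evals['train']:
        metric = next(iter(evals['train']))  # first recorded metric name
        n = len(evals['train'][metric])      # one entry per boosting round
    return n
```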
707
 
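The nested try/except chain in the same function exists because some LightGBM releases reject certain `show_info` keys or the `figsize` argument. The same retry logic can be expressed as a loop; a sketch assuming a trained `lightgbm.Booster` named `booster`:

```python
# Sketch: the show_info fallback chain as a loop, richest options first.
# Assumes a trained lightgbm.Booster `booster`; graphviz must be installed.
import lightgbm as lgb
import matplotlib.pyplot as plt

def plot_tree_with_fallback(booster, tree_index=0):
    fig, ax = plt.subplots(figsize=(20, 12))
    attempts = [
        dict(show_info=['split_gain', 'internal_value', 'internal_count', 'leaf_count'], precision=3),
        dict(show_info=['split_gain', 'internal_count'], precision=2),
        dict(),  # minimal parameters as the last resort
    ]
    for kwargs in attempts:
        try:
            lgb.plot_tree(booster, tree_index=tree_index, ax=ax, **kwargs)
            return fig
        except Exception as err:
            print(f"plot_tree attempt {kwargs} failed: {err}")
    raise Exception("All lgb.plot_tree parameter combinations failed")
```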
708
+ def create_lightgbm_tree_plot(tree_index, feature_cols, problem_type, model, num_leaves=None):
709
  """Create tree visualization for LightGBM trees"""
710
  try:
711
+ # Use provided num_leaves or get from model params
712
+ if num_leaves is None:
713
+ num_leaves = model.params.get('num_leaves', 31) if hasattr(model, 'params') else 31
714
  # Create a representative visualization for LightGBM tree
715
+ return create_manual_tree_plot(tree_index, feature_cols, problem_type, "LightGBM", 1.0, model, num_leaves)
716
 
717
  except Exception as e:
718
  # Fallback to manual tree creation
719
+ return create_manual_tree_plot(tree_index, feature_cols, problem_type, "LightGBM", 1.0, None, num_leaves or 31)
720
 
721
 
722
+ def create_manual_tree_plot(tree_index, feature_cols, problem_type, model_type, weight=1.0, model=None, num_leaves=None):
723
  """Create a manual tree visualization when tree structure is not easily accessible"""
724
+ fig, ax = plt.subplots(figsize=(12, 8), dpi=100)
725
 
726
+ # Create a simple tree visualization
727
  import random
728
  random.seed(tree_index) # Consistent trees for same index
729
 
730
+ # Determine actual number of leaves to use
731
+ if num_leaves is not None:
732
+ actual_leaves = int(num_leaves)
733
+ elif model_type == "LightGBM" and model:
734
  try:
735
+ actual_leaves = model.params.get('num_leaves', 31) if hasattr(model, 'params') else 31
736
  except:
737
+ actual_leaves = 31
738
  else:
739
+ actual_leaves = 31
740
 
741
+ # Simple tree structure
742
  root_feature = random.choice(feature_cols) if feature_cols else "feature_0"
743
  root_threshold = round(random.uniform(0.1, 5.0), 2)
744
 
745
+ # Create a simple tree diagram
746
+ ax.text(0.5, 0.9, f"{model_type} Tree {tree_index + 1}",
747
+ ha='center', va='center', fontsize=16, fontweight='bold', transform=ax.transAxes)
748
 
749
+ ax.text(0.5, 0.7, f"Root: {root_feature} ≤ {root_threshold}",
750
+ ha='center', va='center', fontsize=14, transform=ax.transAxes,
751
+ bbox=dict(boxstyle="round,pad=0.3", facecolor='#8E44AD', alpha=0.7))
752
 
753
+ ax.text(0.2, 0.4, f"Left Leaf\nOutput: {round(random.uniform(-1, 1), 3)}\nSamples: {random.randint(20, 80)}",
754
+ ha='center', va='center', fontsize=12, transform=ax.transAxes,
755
+ bbox=dict(boxstyle="round,pad=0.3", facecolor='#3498DB', alpha=0.7))
756
 
757
+ ax.text(0.8, 0.4, f"Right Leaf\nOutput: {round(random.uniform(-1, 1), 3)}\nSamples: {random.randint(20, 80)}",
758
+ ha='center', va='center', fontsize=12, transform=ax.transAxes,
759
+ bbox=dict(boxstyle="round,pad=0.3", facecolor='#3498DB', alpha=0.7))
760
 
761
+ # Draw arrows
762
+ ax.annotate('', xy=(0.2, 0.5), xytext=(0.4, 0.7),
763
+ arrowprops=dict(arrowstyle='->', lw=2, color='gray'))
764
+ ax.annotate('', xy=(0.8, 0.5), xytext=(0.6, 0.7),
765
+ arrowprops=dict(arrowstyle='->', lw=2, color='gray'))
766
 
767
+ # Add tree info
768
+ title_suffix = f"Leaf-wise Tree ({actual_leaves} leaves)" if model_type == "LightGBM" else "Decision Tree"
769
+ ax.text(0.5, 0.1, f"{title_suffix} - {problem_type.title()}",
770
+ ha='center', va='center', fontsize=12, transform=ax.transAxes)
771
 
772
+ ax.set_xlim(0, 1)
773
+ ax.set_ylim(0, 1)
774
+ ax.axis('off')
775
+
776
+ plt.tight_layout()
777
  return fig
778
 
779
 
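Note that `create_manual_tree_plot` draws a seeded schematic, not the model's real splits. If the actual root split were needed, one alternative (not used in this commit) is to read it from `Booster.dump_model()`; a sketch assuming a trained Booster `booster`:

```python
# Sketch (not part of this demo): read a real root split from dump_model()
# instead of drawing a seeded placeholder. Assumes a trained Booster `booster`.
def describe_root_split(booster, tree_index=0):
    root = booster.dump_model()['tree_info'][tree_index]['tree_structure']
    if 'split_feature' in root:  # internal node; leaves carry 'leaf_value' instead
        name = booster.feature_name()[root['split_feature']]
        return f"root splits on {name} <= {root['threshold']:.3f}"
    return "tree is a single leaf"
```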
780
+ def get_individual_tree_visualization(model, tree_index, feature_cols, problem_type, num_leaves=None):
781
+ return create_individual_tree_visualization(model, tree_index, feature_cols, problem_type, num_leaves)
782
 
783
 
784
  def create_feature_importance_plot(model, feature_cols):
 
787
  importances = model.feature_importance(importance_type='gain')
788
  order = np.argsort(importances)[::-1]
789
 
790
+ # Prepare data for Plotly
791
+ sorted_features = [feature_cols[i] for i in order]
792
+ sorted_importances = importances[order]
793
+
794
  fig = go.Figure()
795
+
796
+ # Create interactive bar plot
797
+ fig.add_trace(go.Bar(
798
+ x=sorted_features,
799
+ y=sorted_importances,
800
+ text=[f"{imp:.0f}" for imp in sorted_importances],
801
  textposition="auto",
802
+ marker_color='#8E44AD',
803
+ marker_line=dict(color='#6C3483', width=1),
804
+ hovertemplate='<b>%{x}</b><br>Importance: %{y:.0f}<extra></extra>',
805
+ name='Feature Importance'
806
+ ))
807
+
808
  fig.update_layout(
809
  title="LightGBM Feature Importance (Gain)",
810
  xaxis_title="Features",
811
  yaxis_title="Importance Score",
812
  plot_bgcolor="white",
813
+ height=500,
814
+ hovermode='closest',
815
  margin=dict(l=40, r=40, t=60, b=40),
816
+ xaxis=dict(
817
+ tickangle=45,
818
+ showgrid=True,
819
+ gridwidth=1,
820
+ gridcolor='lightgray'
821
+ ),
822
+ yaxis=dict(
823
+ showgrid=True,
824
+ gridwidth=1,
825
+ gridcolor='lightgray'
826
+ )
827
  )
828
+
829
+ # Add interactive features
830
+ fig.update_traces(
831
+ marker_line_width=1,
832
+ marker_line_color='#6C3483'
833
+ )
834
+
835
  return fig
836
  except:
837
  fig = go.Figure()
 
844
  )
845
  fig.update_layout(
846
  title="LightGBM Feature Importance",
847
+ height=500,
848
  plot_bgcolor="white"
849
  )
850
  return fig
 
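For context on the plot above: `importance_type='gain'` sums the loss reduction of every split a feature participates in, while `'split'` merely counts splits. A small standalone sketch contrasting the two, assuming a trained `lightgbm.Booster` named `booster` and the same `feature_cols` list:

```python
# Sketch: compare gain-based and split-count importance for the top-k features.
# Assumes a trained lightgbm.Booster `booster`; `feature_cols` names its inputs.
import numpy as np

def top_features(booster, feature_cols, k=5):
    gain = booster.feature_importance(importance_type='gain')    # total split gain
    split = booster.feature_importance(importance_type='split')  # number of splits
    order = np.argsort(gain)[::-1][:k]                           # largest gain first
    return [(feature_cols[i], float(gain[i]), int(split[i])) for i in order]
```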
873
  return f"Predicted Value: {prediction:.3f}"
874
 
875
 
876
+ def create_algorithm_summary(model, problem_type, n_estimators, num_leaves, min_data_in_leaf, learning_rate, feature_cols):
877
  num_trees = model.num_trees() if hasattr(model, 'num_trees') else n_estimators
878
  return f"""
879
  **LightGBM {problem_type.title()} Model Summary:**
880
  - Trees Built: {num_trees}
881
+ - Number of Leaves: {num_leaves}
882
+ - Min Data in Leaf: {min_data_in_leaf}
883
  - Learning Rate: {learning_rate}
884
  - Features: {len(feature_cols)}
885
+ - Algorithm: Leaf-wise Gradient Boosting (LightGBM)
886
  """
887
 
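A hedged usage sketch for `create_algorithm_summary`: train a tiny Booster on synthetic data and print the summary. Everything except the summary function itself is illustrative:

```python
# Usage sketch for create_algorithm_summary; data and parameters are synthetic.
import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y = X @ np.array([1.0, -2.0, 0.5, 0.0]) + rng.normal(scale=0.1, size=200)
feature_cols = [f"f{i}" for i in range(4)]

params = {'objective': 'regression', 'num_leaves': 15,
          'min_data_in_leaf': 10, 'learning_rate': 0.1, 'verbose': -1}
booster = lgb.train(params, lgb.Dataset(X, label=y), num_boost_round=50)

print(create_algorithm_summary(booster, 'regression', 50, 15, 10, 0.1, feature_cols))
```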
888