tree_decision / tree_decision.py
Huydinh1205's picture
other file
46a9203
raw
history blame
2.71 kB
from __future__ import annotations
import numpy as np
def compute_gini(y):
    """Return the Gini impurity of the class labels in ``y``.

    The impurity is ``1 - sum(p_k ** 2)``, where ``p_k`` is the fraction of
    samples that carry label ``k``.

    Args:
        y: 1-D array-like of non-negative class labels. Values are cast to
            ``int`` before counting, so non-integer labels are truncated.

    Returns:
        The impurity of ``y``; ``0.0`` when ``y`` is empty.

    Raises:
        ValueError: If a label is negative (raised by ``np.bincount``).
    """
    # np.asarray lets plain lists/tuples through; ndarray input is unchanged.
    labels = np.asarray(y).astype(int)
    m = len(labels)
    if m == 0:
        # Nothing to be mixed, and it avoids dividing the counts by zero.
        return 0.0
    proportions = np.bincount(labels) / m
    return 1 - sum(proportions ** 2)
def split_node(feature, y):
    """Find the threshold on one feature with the lowest weighted Gini impurity.

    Candidate thresholds are the midpoints between consecutive values of the
    sorted feature. A candidate sends samples with ``feature <= threshold``
    to the left group and samples with ``feature > threshold`` to the right
    group. Candidates that leave either group empty are ignored, because they
    do not actually split the samples.

    Args:
        feature: 1-D array-like with the values of a single feature, one per
            sample.
        y: 1-D array-like of class labels aligned with ``feature``.

    Returns:
        A ``(threshold, gini)`` tuple for the best candidate. When no
        candidate produces two non-empty groups (fewer than two samples, or
        every value of the feature is the same) the result is
        ``(None, float("inf"))``.
    """
    feature = np.asarray(feature)
    y = np.asarray(y)
    m = len(y)
    best_gini = float("inf")
    best_average = None
    feature_sorted = np.sort(feature)
    previous_average = None
    for low, high in zip(feature_sorted[:-1], feature_sorted[1:]):
        average = (low + high) / 2
        # Runs of equal values repeat the same midpoint; re-scoring it cannot
        # win because only a strictly lower impurity replaces the best.
        if previous_average is not None and average == previous_average:
            continue
        previous_average = average
        y_left = y[feature <= average]
        y_right = y[feature > average]
        if len(y_left) == 0 or len(y_right) == 0:
            # A one-sided "split" separates nothing; a caller that recursed
            # on it would get the same data back.
            continue
        gini = (len(y_left) / m) * compute_gini(y_left) + (
            len(y_right) / m
        ) * compute_gini(y_right)
        if gini < best_gini:
            best_gini = gini
            best_average = average
    return best_average, best_gini
class Node:
    """A node of a binary decision tree.

    An internal node stores the name of the feature it tests (``feature``)
    and the threshold (``branch``); its first child handles samples whose
    feature value is ``<= branch`` and its second child handles the rest.
    A leaf node stores the predicted ``value``.
    """

    def __init__(self, feature=None, branch=None, value=None):
        """Create a non-leaf node with no children.

        Args:
            feature: Key used to look the tested feature up in the sample
                passed to ``search``.
            branch: Threshold the feature value is compared against.
            value: Prediction returned when the node is a leaf.
        """
        self.feature = feature
        self.branch = branch
        # Children in order: index 0 is the "<= branch" side, index 1 the other.
        self.node_children = []
        self.is_leaf = False
        self.value = value

    def __str__(self):
        """Return a one-line summary of the node's fields."""
        return f"Feature: {self.feature}, Branch: {self.branch}, Value: {self.value}, Leaf: {self.is_leaf}"

    def add_child(self, node):
        """Append ``node`` to this node's children."""
        self.node_children.append(node)

    def set_leaf(self, value):
        """Set the ``is_leaf`` flag to ``value``."""
        self.is_leaf = value

    def search(self, x_dict):
        """Return the prediction for one sample by walking down the tree.

        Args:
            x_dict: Mapping from feature name to that feature's value for the
                sample.

        Returns:
            The ``value`` of the leaf the sample ends up in.
        """
        if self.is_leaf:
            return self.value
        # Use "<=" so a value equal to the threshold goes to the first child,
        # matching how this file partitions rows when it builds the tree.
        if x_dict[self.feature] <= self.branch:
            return self.node_children[0].search(x_dict)
        return self.node_children[1].search(x_dict)
def _make_leaf(value):
    """Build a leaf node that predicts ``value``."""
    leaf = Node(value=value)
    leaf.set_leaf(True)
    return leaf


def construct_decision_tree(x, y, feature_names):
    """Recursively build a binary decision tree using Gini-based splits.

    At every node the best threshold of each feature is found with
    ``split_node`` and the feature with the lowest weighted impurity is
    chosen. Rows whose value is ``<=`` the threshold form the first child and
    the remaining rows form the second child.

    Args:
        x: 2-D array-like of shape ``(n_samples, n_features)``.
        y: 1-D array-like of class labels, one per row of ``x``. Labels are
            cast to ``int`` when a majority class has to be computed.
        feature_names: Sequence of feature names, one per column of ``x``;
            the chosen name is stored on the node.

    Returns:
        The root ``Node`` of the (sub)tree. A leaf is returned when all
        labels agree, when there are no features, or when no threshold
        separates the rows into two non-empty groups (in the last two cases
        the leaf predicts the majority class).

    Raises:
        ValueError: If ``y`` is empty.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if len(y) == 0:
        raise ValueError("cannot build a decision tree from an empty dataset")
    if len(np.unique(y)) == 1:
        return _make_leaf(y[0])

    majority = np.bincount(y.astype(int)).argmax()
    # len() rather than .size so a plain list of names is accepted as well.
    if len(feature_names) == 0 or x.shape[1] == 0:
        return _make_leaf(majority)

    split_values_gini = [split_node(x[:, i], y) for i in range(x.shape[1])]
    best_feature_index = int(np.argmin([g for _, g in split_values_gini]))
    split_value = split_values_gini[best_feature_index][0]
    if split_value is None:
        # No feature offered a usable threshold.
        return _make_leaf(majority)

    goes_left = x[:, best_feature_index] <= split_value
    if goes_left.all() or not goes_left.any():
        # One side would receive every row (e.g. identical rows with different
        # labels); recursing on it would never terminate.
        return _make_leaf(majority)

    node = Node(feature=feature_names[best_feature_index], branch=split_value)
    node.add_child(
        construct_decision_tree(x[goes_left], y[goes_left], feature_names)
    )
    node.add_child(
        construct_decision_tree(x[~goes_left], y[~goes_left], feature_names)
    )
    return node