# tree_decision / app.py
# Gradio app that runs a pickled decision tree over an uploaded CSV file.
# Author: Huydinh1205 · commit 67098ed "Add application file" · 1.08 kB
from __future__ import annotations
import gradio as gr
import pandas as pd
import numpy as np
import pickle
from tree_decision import Node
# Restore the trained decision tree together with the feature names it uses
# as lookup keys; the pickle holds both, unpacked here into module globals.
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# only deploy a tree.pkl produced by a trusted training run.
with open("tree.pkl", mode="rb") as f:
    tree, feature_names = pickle.load(f)
# Convert one data row into a {feature_name: value} dict for tree.search.
def data_dict(row, names=None):
    """Pair feature names with the values of one data row, by position.

    Args:
        row: Indexable sequence of feature values (e.g. tuple, list or NumPy
            array) in the same order as ``names``.
        names: Feature names to use as keys. Defaults to the module-level
            ``feature_names`` loaded with the tree, so existing
            single-argument calls behave exactly as before.

    Returns:
        dict mapping ``names[i]`` to ``row[i]``. Extra trailing values in
        ``row`` are ignored.

    Raises:
        IndexError: If ``row`` has fewer values than there are names.
    """
    if names is None:
        names = feature_names
    # Index into row (rather than zip) so a too-short row still raises
    # IndexError instead of silently producing a partial dict.
    return {name: row[i] for i, name in enumerate(names)}
# CSV processing: predict one label per row of the uploaded file.
def predict_csv(file):
    """Predict a label for every row of an uploaded CSV file.

    Args:
        file: Value delivered by the ``gr.File`` input. Its ``name``
            attribute is used as the CSV path when present (as before);
            otherwise the value itself is treated as a path.
            NOTE(review): assumes Gradio passes a temp-file object or a
            path-like string here -- confirm for the pinned Gradio version.

    Returns:
        pandas.DataFrame: the uploaded table with an added ``prediction``
        column holding the ``tree.search`` result for each row. An existing
        ``prediction`` column is overwritten.

    Raises:
        gr.Error: if no file was uploaded, the file cannot be parsed as CSV,
            or a required feature column is missing; Gradio shows the
            message to the user.
    """
    if file is None:
        raise gr.Error("❌ Please upload a CSV file")
    path = getattr(file, "name", file)
    try:
        df = pd.read_csv(path)
    except (
        pd.errors.EmptyDataError,
        pd.errors.ParserError,
        UnicodeDecodeError,
    ) as err:
        raise gr.Error(f"❌ Could not read CSV: {err}") from err

    # Check columns. The output component is a gr.Dataframe, so an error
    # string returned from here would be handed to the table rather than
    # shown as a message; raising gr.Error reports it to the user.
    for name in feature_names:
        if name not in df.columns:
            raise gr.Error(f"❌ CSV must contain column '{name}'")

    # Predict each row. itertuples keeps each column's own dtype (iterrows
    # builds a Series per row, upcasting mixed columns to a common dtype)
    # and is much faster. list(...) keeps the column selection working even
    # if feature_names was pickled as a tuple (df[tuple] is a single key).
    features = df[list(feature_names)]
    df["prediction"] = [
        tree.search(data_dict(values))
        for values in features.itertuples(index=False, name=None)
    ]
    return df
# Gradio interface: a CSV upload in, the predictions table out.
demo = gr.Interface(
    predict_csv,
    inputs=gr.File(label="Upload CSV"),
    outputs=gr.Dataframe(label="Predictions"),
)
if __name__ == "__main__":
    # Only serve the UI when run as a script, not when imported.
    # share=True additionally asks Gradio for a temporary public link.
    demo.launch(share=True)