Huydinh1205 commited on
Commit
67098ed
·
1 Parent(s): c3cfc70

Add application file

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import gradio as gr
3
+ import pandas as pd
4
+ import numpy as np
5
+ import pickle
6
+ from tree_decision import Node
7
+ # Load cây đã huấn luyện
8
+ with open("tree.pkl", "rb") as f:
9
+ tree, feature_names = pickle.load(f)
10
+
11
+ # Hàm convert dòng dữ liệu thành dict
12
+ def data_dict(row):
13
+ return {feature_names[i]: row[i] for i in range(len(feature_names))}
14
+
15
+ # Hàm xử lý CSV
16
+ def predict_csv(file):
17
+ df = pd.read_csv(file.name)
18
+
19
+ # Kiểm tra cột
20
+ for name in feature_names:
21
+ if name not in df.columns:
22
+ return f"❌ CSV must contain column '{name}'"
23
+
24
+ # Predict từng dòng
25
+ predictions = []
26
+ for i, row in df[feature_names].iterrows():
27
+ row_dict = data_dict(row.values)
28
+ pred = tree.search(row_dict)
29
+ predictions.append(pred)
30
+
31
+ df['prediction'] = predictions
32
+ return df
33
+
34
+ # Giao diện Gradio
35
+ demo = gr.Interface(
36
+ fn=predict_csv,
37
+ inputs=gr.File(label="Upload CSV"),
38
+ outputs=gr.Dataframe(label="Predictions")
39
+ )
40
+
41
+ if __name__ == "__main__":
42
+ demo.launch(share=True)