Update app.py
Browse files
app.py
CHANGED
@@ -1,66 +1,67 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from transformers import pipeline
|
3 |
import re
|
|
|
|
|
4 |
|
5 |
-
# Load a
|
6 |
-
classifier = pipeline("
|
7 |
-
|
8 |
-
# Define categories for classification
|
9 |
-
categories = ["Saving", "Need", "Want", "Investment"]
|
10 |
-
|
11 |
-
# Helper function to extract information from input
|
12 |
-
def extract_info(user_input):
|
13 |
-
# Remove Vietnamese accents (khong dau processing)
|
14 |
-
normalized_input = re.sub(r'[àáạảãâầấậẩẫăằắặẳẵèéẹẻẽêềếệểễìíịỉĩòóọỏõôồốộổỗơờớợởỡùúụủũưừứựửữỳýỵỷỹđ]', '', user_input)
|
15 |
-
normalized_input = normalized_input.replace("đ", "d")
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
-
#
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
|
34 |
-
|
|
|
|
|
|
|
35 |
|
36 |
-
#
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
40 |
|
41 |
-
#
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
# Extract sub-category
|
47 |
-
sub_category = input_text.lower()
|
48 |
-
return {
|
49 |
-
"classification": classification,
|
50 |
-
"amount": amount_match.group(0),
|
51 |
-
"sub_category": sub_category
|
52 |
-
}
|
53 |
|
54 |
-
# Define
|
55 |
def process_user_input(user_input):
|
56 |
-
|
|
|
57 |
|
58 |
iface = gr.Interface(
|
59 |
fn=process_user_input,
|
60 |
inputs="text",
|
61 |
outputs="text",
|
62 |
-
title="Expenditure
|
63 |
-
description="Classify expenditures into Need, Want, Saving
|
64 |
)
|
65 |
|
66 |
iface.launch()
|
|
|
|
|
|
|
1 |
import re
|
2 |
+
from transformers import pipeline
|
3 |
+
import gradio as gr
|
4 |
|
5 |
+
# Load a lightweight model for classification
|
6 |
+
# NOTE(review): the checkpoint name indicates an SST-2 sentiment model, so the
# labels it emits are presumably sentiment classes rather than spending
# categories — confirm this is the intended fallback model.
# NOTE(review): `return_all_scores` appears to be deprecated in newer
# transformers releases in favour of `top_k` — verify against the installed version.
classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english", return_all_scores=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
+
# Define main and subcategories
|
9 |
+
# Main categories mapped to the sub categories they contain. Dict order and
# list order are significant: lookups walk this mapping front to back and stop
# at the first sub category whose keywords match.
CATEGORIES = {
    "Need": [
        "utilities",
        "housing",
        "groceries",
        "transportation",
        "education",
        "medical",
        "insurance",
        "childcare",
    ],
    "Want": [
        "dining out",
        "entertainment",
        "travel",
        "fitness",
        "shopping",
        "hobbies",
        "personal care",
    ],
    "Saving/Investment": [
        "emergency fund",
        "retirement",
        "investments",
        "debt repayment",
        "education fund",
        "savings for goals",
        "health savings",
    ],
}
|
20 |
|
21 |
+
# Predefined keywords for fast classification
|
22 |
+
# Keyword lists per sub category, used for substring matching on normalized
# (lower-case, accent-free) input. Each key must be spelled exactly like a sub
# category name in CATEGORIES, otherwise its keywords are never looked up.
KEYWORDS = {
    # Previously keyed as "saving", which is not a sub category name and was
    # therefore unreachable. NOTE(review): "savings for goals" is the closest
    # existing sub category — confirm this is the intended target.
    "savings for goals": ["gui tiet kiem", "tiet kiem", "lai suat", "savings", "interest"],
    "utilities": ["electricity", "water", "gas", "internet", "phone"],
    "housing": ["rent", "mortgage", "property tax", "maintenance"],
    "groceries": ["food", "beverages", "supermarket"],
    "transportation": ["gas", "car", "vehicle", "public transit"],
    "education": ["tuition", "books", "school", "course"],
    "medical": ["insurance", "doctor", "prescriptions", "medicine"],
    "dining out": ["restaurant", "cafe", "fast food", "delivery"],
    # Add more keywords for all subcategories...
}
|
33 |
|
34 |
+
# Normalize Vietnamese input (remove accents)
|
35 |
+
def normalize_vietnamese(text):
    """Return *text* with diacritics folded away ("tiết kiệm" -> "tiet kiem").

    Accented letters are reduced to their base letter rather than deleted, so
    the result can be compared against accent-free keywords.

    Args:
        text: The string to normalize.

    Returns:
        The string with combining marks removed and "đ"/"Đ" mapped to "d"/"D".
    """
    import unicodedata  # local import so the block does not rely on file-level imports

    # "đ"/"Đ" have no canonical decomposition, so NFD cannot split them into a
    # base letter plus a mark; map them explicitly first.
    folded = text.replace("đ", "d").replace("Đ", "D")
    # NFD separates every accented letter into its base letter followed by
    # combining marks; dropping the marks leaves the bare letters.
    decomposed = unicodedata.normalize("NFD", folded)
    return "".join(ch for ch in decomposed if not unicodedata.combining(ch))
|
37 |
|
38 |
+
# Classify input into main and subcategories
|
39 |
+
def classify_input(user_input):
    """Classify an expenditure description into a main and a sub category.

    The input is lower-cased and passed through ``normalize_vietnamese``, then
    compared (substring match) against the ``KEYWORDS`` of each sub category in
    ``CATEGORIES``; the first sub category with a matching keyword wins. When
    nothing matches, the module-level ``classifier`` pipeline is consulted.

    Args:
        user_input: Free-text description of the expenditure.

    Returns:
        A dict with the keys ``"Main Category"`` and ``"Sub Category"``. On the
        fallback path the sub category is ``"Unknown"`` and the main category
        is the highest-scoring label reported by ``classifier``.
    """
    normalized_input = normalize_vietnamese(user_input.lower())

    # Keyword lookup first: cheap, and avoids running the model at all.
    for main_cat, subcats in CATEGORIES.items():
        for subcat in subcats:
            keywords = KEYWORDS.get(subcat, [])
            if any(keyword in normalized_input for keyword in keywords):
                return {"Main Category": main_cat, "Sub Category": subcat.capitalize()}

    # Fallback to model classification.
    result = classifier(normalized_input)
    # When all scores are requested for a single string, the pipeline wraps
    # the per-label dicts in an outer list ([[{...}, {...}]]). Unwrap it so
    # max() sees the score dicts; indexing the inner list with "score" would
    # raise TypeError.
    if isinstance(result, dict):
        result = [result]
    elif result and isinstance(result[0], list):
        result = result[0]
    category = max(result, key=lambda entry: entry["score"])["label"]
    return {"Main Category": category, "Sub Category": "Unknown"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
+
# Define Gradio interface
|
55 |
def process_user_input(user_input):
    """Classify *user_input* and render the result as two lines of text.

    Args:
        user_input: Free-text description of the expenditure.

    Returns:
        A string of the form ``"Main Category: <main>\\nSub Category: <sub>"``
        built from the dict returned by ``classify_input``.
    """
    classification = classify_input(user_input)
    return (
        f"Main Category: {classification['Main Category']}\n"
        f"Sub Category: {classification['Sub Category']}"
    )
|
58 |
|
59 |
# Text-in / text-out Gradio UI wired to process_user_input.
iface = gr.Interface(
    fn=process_user_input,
    inputs="text",
    outputs="text",
    title="Expenditure Classifier",
    description="Classify expenditures into main and subcategories (Need, Want, Saving/Investment).",
)

# Start the web UI when the module is executed.
iface.launch()
|