Meomap committed on
Commit
0f90513
·
verified ·
1 Parent(s): 26cf865

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -48
app.py CHANGED
@@ -1,66 +1,67 @@
1
- import gradio as gr
2
- from transformers import pipeline
3
  import re
 
 
4
 
5
- # Load a Hugging Face pipeline for zero-shot classification
6
- classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
7
-
8
- # Define categories for classification
9
- categories = ["Saving", "Need", "Want", "Investment"]
10
-
11
- # Helper function to extract information from input
12
- def extract_info(user_input):
13
- # Remove Vietnamese accents (khong dau processing)
14
- normalized_input = re.sub(r'[àáạảãâầấậẩẫăằắặẳẵèéẹẻẽêềếệểễìíịỉĩòóọỏõôồốộổỗơờớợởỡùúụủũưừứựửữỳýỵỷỹđ]', '', user_input)
15
- normalized_input = normalized_input.replace("đ", "d")
16
 
17
- # Predefined rules for classification and information extraction
18
- if "gui tiet kiem" in normalized_input:
19
- return classify_saving(normalized_input)
20
- else:
21
- return classify_other(normalized_input)
 
 
 
 
 
 
 
22
 
23
- # Function to classify saving-related expenditure
24
- def classify_saving(input_text):
25
- details = {}
26
- # Extract potential fields from input text using regex
27
- term_match = re.search(r'(\d+)\s*thang', input_text)
28
- details['term'] = term_match.group(1) + " tháng" if term_match else None
 
 
 
 
 
 
29
 
30
- # Ask for missing fields
31
- if not details.get('term'):
32
- return "Ban gui tiet kiem bao nhieu thang?"
33
 
34
- return "Saving: {}".format(details)
 
 
 
35
 
36
- # Function to classify other expenditures
37
- def classify_other(input_text):
38
- result = classifier(input_text, candidate_labels=categories)
39
- classification = result['labels'][0] # Take the top classification
 
40
 
41
- # Ask for amount and other details if not mentioned
42
- amount_match = re.search(r'(\d+\.?\d*)', input_text)
43
- if not amount_match:
44
- return "Ban chi tieu nay het bao nhieu tien?"
45
-
46
- # Extract sub-category
47
- sub_category = input_text.lower()
48
- return {
49
- "classification": classification,
50
- "amount": amount_match.group(0),
51
- "sub_category": sub_category
52
- }
53
 
54
- # Define the Gradio interface
55
  def process_user_input(user_input):
56
- return extract_info(user_input)
 
57
 
58
  iface = gr.Interface(
59
  fn=process_user_input,
60
  inputs="text",
61
  outputs="text",
62
- title="Expenditure Classification",
63
- description="Classify expenditures into Need, Want, Saving, or Investment based on the 50-30-20 rule. Type in Vietnamese (không dấu)!"
64
  )
65
 
66
  iface.launch()
 
 
 
1
  import re
2
+ from transformers import pipeline
3
+ import gradio as gr
4
 
5
+ # Load a lightweight model for classification
6
+ classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english", return_all_scores=True)
 
 
 
 
 
 
 
 
 
7
 
8
+ # Define main and subcategories
9
+ CATEGORIES = {
10
+ "Need": [
11
+ "utilities", "housing", "groceries", "transportation", "education", "medical", "insurance", "childcare"
12
+ ],
13
+ "Want": [
14
+ "dining out", "entertainment", "travel", "fitness", "shopping", "hobbies", "personal care"
15
+ ],
16
+ "Saving/Investment": [
17
+ "emergency fund", "retirement", "investments", "debt repayment", "education fund", "savings for goals", "health savings"
18
+ ]
19
+ }
20
 
21
+ # Predefined keywords for fast classification
22
+ KEYWORDS = {
23
+ "saving": ["gui tiet kiem", "tiet kiem", "lai suat", "savings", "interest"],
24
+ "utilities": ["electricity", "water", "gas", "internet", "phone"],
25
+ "housing": ["rent", "mortgage", "property tax", "maintenance"],
26
+ "groceries": ["food", "beverages", "supermarket"],
27
+ "transportation": ["gas", "car", "vehicle", "public transit"],
28
+ "education": ["tuition", "books", "school", "course"],
29
+ "medical": ["insurance", "doctor", "prescriptions", "medicine"],
30
+ "dining out": ["restaurant", "cafe", "fast food", "delivery"],
31
+ # Add more keywords for all subcategories...
32
+ }
33
 
34
+ # Normalize Vietnamese input (remove accents)
35
def normalize_vietnamese(text):
    """Return *text* with Vietnamese diacritics removed ("khong dau" form).

    Accented letters are mapped to their unaccented base letter instead of
    being deleted, so "gửi tiết kiệm" becomes "gui tiet kiem" and can match
    the unaccented keyword lists.

    Args:
        text: Input string, in any letter case.

    Returns:
        The string with combining marks dropped and "đ"/"Đ" mapped to
        "d"/"D". All other characters are left unchanged.
    """
    import unicodedata

    # "đ"/"Đ" have no canonical decomposition, so NFD alone would leave them
    # untouched; map them explicitly before stripping marks.
    text = text.replace("đ", "d").replace("Đ", "D")

    # NFD splits each accented letter into base letter + combining marks
    # (tone, horn, breve, circumflex); dropping the marks keeps the base.
    decomposed = unicodedata.normalize("NFD", text)
    return "".join(ch for ch in decomposed if not unicodedata.combining(ch))
37
 
38
+ # Classify input into main and subcategories
39
def classify_input(user_input):
    """Classify an expenditure description into a main and sub category.

    Keyword matching is tried first; the model is used only as a fallback.

    Args:
        user_input: Free-text description of the expenditure.

    Returns:
        A dict with the keys "Main Category" and "Sub Category". Both values
        are "Unknown" for empty input; "Sub Category" is "Unknown" whenever
        the model fallback is used.
    """
    # Nothing to classify: avoid calling .lower() on None or sending an
    # empty string to the model.
    if not user_input or not user_input.strip():
        return {"Main Category": "Unknown", "Sub Category": "Unknown"}

    normalized_input = normalize_vietnamese(user_input.lower())

    # Fast path: first subcategory (in CATEGORIES order) with a keyword hit.
    for main_cat, subcats in CATEGORIES.items():
        for subcat in subcats:
            if any(keyword in normalized_input for keyword in KEYWORDS.get(subcat, [])):
                return {"Main Category": main_cat, "Sub Category": subcat.capitalize()}

    # The "saving" keyword group is not a subcategory name in CATEGORIES, so
    # the loop above never consults it; handle it explicitly.
    # NOTE(review): mapped to "Saving/Investment" based on the group name —
    # confirm the intended subcategory label.
    if any(keyword in normalized_input for keyword in KEYWORDS.get("saving", [])):
        return {"Main Category": "Saving/Investment", "Sub Category": "Saving"}

    # Fallback to model classification.
    result = classifier(normalized_input)
    # A pipeline created with return_all_scores=True wraps the per-label
    # scores of a single input in an outer list; unwrap it so max() iterates
    # over the score dicts instead of raising TypeError.
    if result and isinstance(result[0], list):
        result = result[0]
    # NOTE(review): the label is whatever the configured model emits; with a
    # sentiment model it will not be one of the CATEGORIES keys — confirm.
    category = max(result, key=lambda item: item["score"])["label"]
    return {"Main Category": category, "Sub Category": "Unknown"}
 
 
 
 
 
 
 
 
53
 
54
+ # Define Gradio interface
55
def process_user_input(user_input):
    """Classify *user_input* and render the result as two display lines.

    Returns:
        A string with the main category on the first line and the
        subcategory on the second.
    """
    result = classify_input(user_input)
    main_category = result["Main Category"]
    sub_category = result["Sub Category"]
    lines = (
        f"Main Category: {main_category}",
        f"Sub Category: {sub_category}",
    )
    return "\n".join(lines)
58
 
59
  iface = gr.Interface(
60
  fn=process_user_input,
61
  inputs="text",
62
  outputs="text",
63
+ title="Expenditure Classifier",
64
+ description="Classify expenditures into main and subcategories (Need, Want, Saving/Investment)."
65
  )
66
 
67
  iface.launch()