dnzblgn commited on
Commit
fda81d5
·
verified ·
1 Parent(s): 2119532

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +184 -0
app.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import faiss
4
+ import os
5
+ import numpy as np
6
+ from sentence_transformers import SentenceTransformer
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
+
9
+ # Model paths
10
+ sent = "sent"
11
+ sarc = "sarc"
12
+ doc = "doc"
13
+ embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
14
+
15
+ # Load sentiment, sarcasm, and classification models
16
+ sentiment_tokenizer = AutoTokenizer.from_pretrained(sent)
17
+ sentiment_model = AutoModelForSequenceClassification.from_pretrained(sent)
18
+ sarcasm_tokenizer = AutoTokenizer.from_pretrained(sarc)
19
+ sarcasm_model = AutoModelForSequenceClassification.from_pretrained(sarc)
20
+ classification_tokenizer = AutoTokenizer.from_pretrained(doc)
21
+ classification_model = AutoModelForSequenceClassification.from_pretrained(doc)
22
+
23
+ # Load Mistral LLM for conversational answers
24
+ mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
25
+ mistral_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16).eval()
26
+
27
+ # Paths and files
28
+ UPLOAD_FOLDER = "uploads"
29
+ SUMMARY_FILE = "summary.txt"
30
+ FAISS_INDEX_PATH = "faiss_index"
31
+ DOCUMENTS_FILE = "documents.txt"
32
+
33
+ if not os.path.exists(UPLOAD_FOLDER):
34
+ os.makedirs(UPLOAD_FOLDER)
35
+
36
+ categories = {
37
+ 0: "Shipping and Delivery",
38
+ 1: "Customer Service",
39
+ 2: "Price and Value",
40
+ 3: "Quality and Performance",
41
+ 4: "Use and Design",
42
+ 5: "Other"
43
+ }
44
+
45
+ # Helper functions
46
+ def analyze_sentiment(sentence):
47
+ inputs = sentiment_tokenizer(sentence, return_tensors="pt", truncation=True, padding=True, max_length=512)
48
+ with torch.no_grad():
49
+ outputs = sentiment_model(**inputs)
50
+ logits = outputs.logits
51
+ sentiment = torch.argmax(logits, dim=-1).item()
52
+ return "Positive" if sentiment == 0 else "Negative"
53
+
54
+ def detect_sarcasm(sentence):
55
+ inputs = sarcasm_tokenizer(sentence, return_tensors="pt", truncation=True, padding=True, max_length=512)
56
+ with torch.no_grad():
57
+ outputs = sarcasm_model(**inputs)
58
+ logits = outputs.logits
59
+ sarcasm = torch.argmax(logits, dim=-1).item()
60
+ return sarcasm == 1
61
+
62
+ def classify_document(sentence):
63
+ inputs = classification_tokenizer(sentence, return_tensors="pt", truncation=True, padding=True, max_length=512)
64
+ with torch.no_grad():
65
+ outputs = classification_model(**inputs)
66
+ logits = outputs.logits
67
+ category = torch.argmax(logits, dim=-1).item()
68
+ return categories[category]
69
+
70
+ def preprocess_summary(file_path):
71
+ with open(file_path, "r", encoding="utf-8") as file:
72
+ lines = file.readlines()
73
+
74
+ chunks = []
75
+ current_chunk = []
76
+
77
+ for line in lines:
78
+ line = line.strip()
79
+ if not line:
80
+ continue
81
+ if line.endswith(":") and current_chunk:
82
+ chunks.append("\n".join(current_chunk))
83
+ current_chunk = []
84
+ current_chunk.append(line)
85
+
86
+ if current_chunk:
87
+ chunks.append("\n".join(current_chunk))
88
+
89
+ return chunks
90
+
91
+ def create_faiss_index(chunks):
92
+ embeddings = [embedding_model.encode(chunk, normalize_embeddings=True) for chunk in chunks]
93
+ embeddings_np = np.array(embeddings)
94
+ embedding_dimension = embeddings_np.shape[1]
95
+
96
+ faiss_index = faiss.IndexFlatL2(embedding_dimension)
97
+ faiss_index.add(embeddings_np)
98
+ faiss.write_index(faiss_index, FAISS_INDEX_PATH)
99
+
100
+ with open(DOCUMENTS_FILE, "w", encoding="utf-8") as doc_file:
101
+ for chunk in chunks:
102
+ doc_file.write(chunk + "\n--END--\n")
103
+
104
+ def handle_uploaded_file(file):
105
+ file_path = os.path.join(UPLOAD_FOLDER, "uploaded_comments.txt")
106
+ file.save(file_path)
107
+
108
+ with open(file_path, "r", encoding="utf-8") as f:
109
+ comments = f.readlines()
110
+
111
+ results = []
112
+ for comment in comments:
113
+ comment = comment.strip()
114
+ if not comment:
115
+ continue
116
+ sentiment = analyze_sentiment(comment)
117
+ if sentiment == "Positive" and detect_sarcasm(comment):
118
+ sentiment = "Negative"
119
+ category = classify_document(comment)
120
+ results.append({"comment": comment, "sentiment": sentiment, "category": category})
121
+
122
+ chunks = preprocess_summary(file_path)
123
+ create_faiss_index(chunks)
124
+
125
+ return "File uploaded and processed successfully."
126
+
127
+ def mistral_generate_response(prompt):
128
+ inputs = mistral_tokenizer(prompt, return_tensors="pt").to("cuda")
129
+ with torch.no_grad():
130
+ outputs = mistral_model.generate(inputs["input_ids"], max_length=500, do_sample=True, temperature=0.7)
131
+ response = mistral_tokenizer.decode(outputs[0], skip_special_tokens=True)
132
+ return response
133
+
134
+ def query_chatbot(query):
135
+ top_k = 5
136
+ faiss_index = faiss.read_index(FAISS_INDEX_PATH)
137
+
138
+ with open(DOCUMENTS_FILE, "r", encoding="utf-8") as doc_file:
139
+ documents = doc_file.read().split("\n--END--\n")
140
+
141
+ query_embedding = embedding_model.encode([query], normalize_embeddings=True)
142
+ distances, indices = faiss_index.search(np.array(query_embedding), top_k)
143
+
144
+ relevant_docs = [documents[idx] for idx in indices[0] if idx < len(documents)]
145
+ context = "\n\n".join(relevant_docs[:top_k])
146
+
147
+ final_prompt = (
148
+ f"Context:\n{context}\n\n"
149
+ f"Question: {query}\n\n"
150
+ f"Your Answer (based on the context):"
151
+ )
152
+
153
+ return mistral_generate_response(final_prompt)
154
+
155
+ # Gradio interface
156
+ with gr.Blocks() as interface:
157
+ gr.Markdown("# Sentiment Analysis Powered by Sarcasm Detection")
158
+ with gr.Row():
159
+ upload = gr.File(label="Upload .txt File")
160
+ chatbot_output = gr.Textbox(label="Processing Report", lines=10, interactive=False)
161
+
162
+ upload_btn = gr.Button("Process File")
163
+
164
+ with gr.Row():
165
+ query_input = gr.Textbox(label="Ask a Question")
166
+ answer_output = gr.Textbox(label="Answer", lines=5, interactive=False)
167
+
168
+ query_btn = gr.Button("Get Answer")
169
+
170
+ def process_file_and_show_chatbot(file):
171
+ result_message = handle_uploaded_file(file)
172
+ return result_message
173
+
174
+ upload_btn.click(process_file_and_show_chatbot, inputs=upload, outputs=chatbot_output)
175
+
176
+ def handle_query(query):
177
+ response = query_chatbot(query)
178
+ return response
179
+
180
+ query_btn.click(handle_query, inputs=query_input, outputs=answer_output)
181
+
182
+ # Run Gradio app
183
+ if __name__ == "__main__":
184
+ interface.launch()