|
import gradio as gr |
|
import torch |
|
import faiss |
|
import os |
|
import numpy as np |
|
from sentence_transformers import SentenceTransformer |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification |
|
|
|
|
|
# Hugging Face Hub ids of the fine-tuned customer-review classifiers.
sent = "dnzblgn/Sentiment-Analysis-Customer-Reviews"
sarc = "dnzblgn/Sarcasm-Detection-Customer-Reviews"
doc = "dnzblgn/Customer-Reviews-Classification"

# Sentence embedder used both to index text chunks and to encode user queries.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Review classifiers. The ids above are reused here instead of repeating the literals,
# so each model id is defined in exactly one place.
sentiment_tokenizer = AutoTokenizer.from_pretrained(sent, use_fast=False)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(sent)

sarcasm_tokenizer = AutoTokenizer.from_pretrained(sarc, use_fast=False)
sarcasm_model = AutoModelForSequenceClassification.from_pretrained(sarc)

classification_tokenizer = AutoTokenizer.from_pretrained(doc, use_fast=False)
classification_model = AutoModelForSequenceClassification.from_pretrained(doc)

# Generative model used to answer questions over the retrieved context.
mistral_model_name = "mistralai/Mistral-7B-v0.1"
causal_tokenizer = AutoTokenizer.from_pretrained(mistral_model_name)
# NOTE(review): weights are loaded in float16 with no explicit device placement;
# half-precision generation on CPU depends on the installed torch version —
# confirm the deployment target (GPU vs CPU) before relying on this dtype.
causal_model = AutoModelForCausalLM.from_pretrained(mistral_model_name, torch_dtype=torch.float16).eval()
|
|
|
|
|
# Working files, all relative to the current working directory.
UPLOAD_FOLDER = "uploads"            # copies of uploaded comment files
SUMMARY_FILE = "summary.txt"         # per-category analysis summary that gets indexed
FAISS_INDEX_PATH = "faiss_index"     # serialized FAISS index
DOCUMENTS_FILE = "documents.txt"     # chunk texts, in the same order as index rows

# exist_ok=True removes the race between checking for the directory and creating it.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
|
|
# Maps the review-classification model's predicted class index to a
# human-readable category name (position in this list == class index).
categories = dict(enumerate([
    "Shipping and Delivery",
    "Customer Service",
    "Price and Value",
    "Quality and Performance",
    "Use and Design",
    "Other",
]))
|
|
|
|
|
def _predict_label(tokenizer, model, text):
    """Return the argmax class index a sequence-classification model assigns to ``text``.

    Shared by the three classifiers below, which previously each repeated this
    tokenize / forward / argmax sequence. Input is truncated to 512 tokens and
    the forward pass runs without gradient tracking.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits
    return torch.argmax(logits, dim=-1).item()


def analyze_sentiment(sentence):
    """Classify ``sentence`` with the sentiment model.

    Returns:
        "Positive" when the model predicts class 0, otherwise "Negative".
    """
    label = _predict_label(sentiment_tokenizer, sentiment_model, sentence)
    return "Positive" if label == 0 else "Negative"


def detect_sarcasm(sentence):
    """Return True when the sarcasm model predicts class 1 for ``sentence``."""
    return _predict_label(sarcasm_tokenizer, sarcasm_model, sentence) == 1


def classify_document(sentence):
    """Return the category name (from ``categories``) predicted for ``sentence``."""
    return categories[_predict_label(classification_tokenizer, classification_model, sentence)]
|
|
|
def preprocess_summary(file_path):
    """Split a UTF-8 text file into sections headed by lines ending in ":".

    Blank lines are dropped and every kept line is stripped of surrounding
    whitespace. A line ending in ":" opens a new section, unless nothing has
    been collected yet. Any lines before the first header form their own
    leading section.

    Returns:
        A list of sections, each being its lines joined with "\\n". The list
        is empty if the file has no non-blank lines.
    """
    sections = []
    pending = []
    with open(file_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text == "":
                continue
            is_header = text.endswith(":")
            if is_header and pending:
                sections.append("\n".join(pending))
                pending = []
            pending.append(text)
    if pending:
        sections.append("\n".join(pending))
    return sections
|
|
|
def create_faiss_index(chunks):
    """Embed ``chunks``, build an L2 FAISS index over them, and persist both.

    The index is written to FAISS_INDEX_PATH. The chunk texts are written to
    DOCUMENTS_FILE in the same order, each followed by "\\n--END--\\n", so a
    search hit's row id is the chunk's position in that file.

    Args:
        chunks: Sequence of text chunks. If it is empty, an empty index and an
            empty documents file are written instead of raising IndexError.
    """
    chunks = list(chunks)
    if chunks:
        # One batched encode call instead of one call per chunk. Embeddings are
        # unit-normalized, so L2 ranking matches cosine-similarity ranking.
        embeddings = embedding_model.encode(chunks, normalize_embeddings=True)
        # FAISS only accepts float32 input.
        embeddings_np = np.asarray(embeddings, dtype="float32")
        embedding_dimension = embeddings_np.shape[1]
    else:
        embeddings_np = None
        embedding_dimension = embedding_model.get_sentence_embedding_dimension()

    faiss_index = faiss.IndexFlatL2(embedding_dimension)
    if embeddings_np is not None:
        faiss_index.add(embeddings_np)
    faiss.write_index(faiss_index, FAISS_INDEX_PATH)

    with open(DOCUMENTS_FILE, "w", encoding="utf-8") as doc_file:
        doc_file.write("".join(chunk + "\n--END--\n" for chunk in chunks))
|
|
|
def _read_upload_text(file):
    """Return the text content of a value handed over by a ``gr.File`` callback.

    Depending on the Gradio version and ``gr.File(type=...)``, the value may be:
    a filesystem path (``str`` / ``os.PathLike``), an object that exposes the
    temp-file path as ``.name``, or raw ``bytes``. A ``str`` that does not name
    an existing file is treated as the text itself. That keeps callers working
    that passed content directly, which is what the old code assumed.

    Raises:
        TypeError: if the value is none of the supported kinds.
    """
    if isinstance(file, bytes):
        return file.decode("utf-8")
    if isinstance(file, os.PathLike):
        file = os.fspath(file)
    if isinstance(file, str):
        if not os.path.isfile(file):
            return file
        path = file
    else:
        path = getattr(file, "name", None)
        if path is None:
            raise TypeError(f"Unsupported upload type: {type(file).__name__}")
    with open(path, "r", encoding="utf-8") as handle:
        return handle.read()


def _format_summary(results):
    """Render per-comment analysis results as category sections.

    Each section starts with a header line that ends in ":", which is the
    delimiter ``preprocess_summary`` splits on. The header carries
    positive/negative counts for the category. One line per comment follows.
    Comment lines end in "]", so a comment that itself ends in ":" cannot be
    mistaken for a header. Categories appear in order of first occurrence.
    Returns "" for no results.
    """
    grouped = {}
    for item in results:
        grouped.setdefault(item["category"], []).append(item)

    sections = []
    for category, items in grouped.items():
        positives = sum(1 for item in items if item["sentiment"] == "Positive")
        negatives = len(items) - positives
        lines = [f"{category} ({positives} positive, {negatives} negative):"]
        lines.extend(f"- {item['comment']} [{item['sentiment']}]" for item in items)
        sections.append("\n".join(lines))

    if not sections:
        return ""
    return "\n\n".join(sections) + "\n"


def handle_uploaded_file(file):
    """Analyze an uploaded comments file and rebuild the retrieval index.

    Steps:
    1. Save a copy of the upload to UPLOAD_FOLDER/uploaded_comments.txt.
    2. Classify every non-blank line for sentiment, sarcasm and category.
       Sarcastic "Positive" comments are counted as "Negative".
    3. Write a per-category summary to SUMMARY_FILE.
    4. Chunk that summary with ``preprocess_summary`` and index it with
       ``create_faiss_index``.

    The previous version wrote the Gradio file value (a path or wrapper object)
    into the copy instead of its contents. It also discarded the analysis
    results and indexed the raw upload, so the question-answering context
    carried no sentiment or category information.

    Args:
        file: Value from a ``gr.File`` component (see ``_read_upload_text``),
            or None when nothing was uploaded.

    Returns:
        A status message for the UI.
    """
    if file is None:
        return "Please upload a .txt file first."

    text = _read_upload_text(file)

    file_path = os.path.join(UPLOAD_FOLDER, "uploaded_comments.txt")
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(text)

    results = []
    for comment in text.splitlines():
        comment = comment.strip()
        if not comment:
            continue
        sentiment = analyze_sentiment(comment)
        # Sarcasm inverts apparent praise, so treat sarcastic positives as negative.
        if sentiment == "Positive" and detect_sarcasm(comment):
            sentiment = "Negative"
        category = classify_document(comment)
        results.append({"comment": comment, "sentiment": sentiment, "category": category})

    if not results:
        return "The uploaded file contains no comments to process."

    with open(SUMMARY_FILE, "w", encoding="utf-8") as f:
        f.write(_format_summary(results))

    chunks = preprocess_summary(SUMMARY_FILE)
    create_faiss_index(chunks)

    return "File uploaded and processed successfully."
|
|
|
def causal_generate_response(prompt, max_new_tokens=500, temperature=0.7):
    """Sample a continuation of ``prompt`` from the causal LM and return only the new text.

    Args:
        prompt: Text fed to the model.
        max_new_tokens: Upper bound on generated tokens, not counting the prompt.
            The old ``max_length=500`` counted the prompt too, so a long
            retrieval context left little or no room for an answer.
        temperature: Sampling temperature.

    Returns:
        The decoded continuation with special tokens removed. The prompt is no
        longer echoed back.
    """
    inputs = causal_tokenizer(prompt, return_tensors="pt").to(causal_model.device)
    input_ids = inputs["input_ids"]
    prompt_length = input_ids.shape[-1]
    with torch.no_grad():
        outputs = causal_model.generate(
            input_ids=input_ids,
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            # The Mistral tokenizer defines no pad token; reuse EOS for padding.
            pad_token_id=causal_tokenizer.eos_token_id,
        )
    # generate() returns prompt + continuation for decoder-only models; keep the continuation.
    new_tokens = outputs[0][prompt_length:]
    return causal_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
def query_chatbot(query, top_k=5):
    """Answer ``query`` with the causal LM, grounded in the most similar indexed chunks.

    Args:
        query: The user's question.
        top_k: Maximum number of chunks to retrieve as context. The value is
            clamped to the number of vectors in the index.

    Returns:
        The generated answer. Returns a short guidance message instead when the
        question is empty or no index has been built yet.
    """
    if not query or not query.strip():
        return "Please enter a question."
    if not (os.path.exists(FAISS_INDEX_PATH) and os.path.exists(DOCUMENTS_FILE)):
        return "Please upload and process a file before asking questions."

    faiss_index = faiss.read_index(FAISS_INDEX_PATH)

    with open(DOCUMENTS_FILE, "r", encoding="utf-8") as doc_file:
        documents = doc_file.read().split("\n--END--\n")

    relevant_docs = []
    k = min(top_k, faiss_index.ntotal)
    if k > 0:
        query_embedding = embedding_model.encode([query], normalize_embeddings=True)
        _, indices = faiss_index.search(np.asarray(query_embedding, dtype="float32"), k)
        # FAISS pads missing hits with -1. Without the lower bound, documents[-1]
        # (the empty string after the last separator) would be pulled in.
        relevant_docs = [documents[idx] for idx in indices[0] if 0 <= idx < len(documents)]
    context = "\n\n".join(relevant_docs)

    final_prompt = (
        "You are a business data analyst. Analyze the feedback data and identify the overall sentiment trends. "
        "Focus on determining whether positive feedback or negative feedback dominates in each category, and avoid overstating less significant trends. "
        "Provide clear, data-driven insights.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Your Answer (based on the data and context):"
    )

    return causal_generate_response(final_prompt)
|
|
|
|
|
|
|
with gr.Blocks() as interface:
    gr.Markdown("# Sentiment Analysis Powered by Sarcasm Detection")

    # Upload area: file picker next to the status report, with a trigger button below.
    with gr.Row():
        upload = gr.File(label="Upload .txt File")
        chatbot_output = gr.Textbox(
            label="Processing Report",
            lines=10,
            interactive=False,
        )
    upload_btn = gr.Button("Process File")

    # Question area: question box next to the answer box, with a trigger button below.
    with gr.Row():
        query_input = gr.Textbox(label="Ask a Question")
        answer_output = gr.Textbox(
            label="Answer",
            lines=5,
            interactive=False,
        )
    query_btn = gr.Button("Get Answer")

    def process_file_and_show_chatbot(file):
        """Run the upload pipeline and hand its status message to the report box."""
        return handle_uploaded_file(file)

    def handle_query(query):
        """Answer the typed question from the indexed feedback."""
        return query_chatbot(query)

    # Wire the buttons; listeners must be registered inside the Blocks context.
    upload_btn.click(fn=process_file_and_show_chatbot, inputs=upload, outputs=chatbot_output)
    query_btn.click(fn=handle_query, inputs=query_input, outputs=answer_output)


if __name__ == "__main__":
    interface.launch()