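"""Customer-review analyzer with a fallback RAG chatbot.

Pipeline: classify each uploaded review (sentiment -> sarcasm correction ->
topic category), summarize the tallies into a short text document, index it
with FAISS, and answer questions about it through a
ConversationalRetrievalChain whose LLM is the first hosted model that
responds.
"""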
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
import os
import time
import logging
from huggingface_hub import InferenceClient
from langchain_core.language_models.llms import LLM
from typing import Optional, List, Mapping, Any
from pydantic import Field, PrivateAttr
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load tokenizer and classification models
sentiment_tokenizer = AutoTokenizer.from_pretrained("dnzblgn/Sentiment-Analysis-Customer-Reviews")
sentiment_model = AutoModelForSequenceClassification.from_pretrained("dnzblgn/Sentiment-Analysis-Customer-Reviews")
sarcasm_tokenizer = AutoTokenizer.from_pretrained("dnzblgn/Sarcasm-Detection-Customer-Reviews")
sarcasm_model = AutoModelForSequenceClassification.from_pretrained("dnzblgn/Sarcasm-Detection-Customer-Reviews")
doc_tokenizer = AutoTokenizer.from_pretrained("dnzblgn/Customer-Reviews-Classification")
doc_model = AutoModelForSequenceClassification.from_pretrained("dnzblgn/Customer-Reviews-Classification")
label_mapping = {
"shipping_and_delivery": 0,
"customer_service": 1,
"price_and_value": 2,
"quality_and_performance": 3,
"use_and_design": 4,
"other": 5
}
reverse_label_mapping = {v: k for k, v in label_mapping.items()}
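# Resolve the HF Inference API token from whichever common env var name is set.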
def get_hf_token():
    for var in ("HF", "HF_TOKEN", "HUGGINGFACE_TOKEN", "HUGGINGFACE_HUB_TOKEN"):
        token = os.environ.get(var)
        if token:
            return token
    return None
def analyze_reviews(reviews):
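    """Classify each review and tally sentiment counts overall and per category."""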
analysis = {
"overall": {"positive": 0, "negative": 0},
"categories": {label: {"positive": 0, "negative": 0} for label in label_mapping}
}
    for review in reviews:
        # Sentiment: class 0 = positive, class 1 = negative.
        sentiment_inputs = sentiment_tokenizer(review, return_tensors="pt", truncation=True, padding=True, max_length=512)
        with torch.no_grad():
            sentiment_class = torch.argmax(sentiment_model(**sentiment_inputs).logits, dim=-1).item()
        sentiment = "positive" if sentiment_class == 0 else "negative"
        # Sarcasm correction: a "positive" review flagged as sarcastic
        # (class 1) is treated as negative.
        if sentiment == "positive":
            sarcasm_inputs = sarcasm_tokenizer(review, return_tensors="pt", truncation=True, padding=True, max_length=512)
            with torch.no_grad():
                sarcasm_class = torch.argmax(sarcasm_model(**sarcasm_inputs).logits, dim=-1).item()
            if sarcasm_class == 1:
                sentiment = "negative"
        # Topic classification into one of the six categories above.
        doc_inputs = doc_tokenizer(review, return_tensors="pt", truncation=True, padding=True, max_length=512)
        with torch.no_grad():
            category_class = torch.argmax(doc_model(**doc_inputs).logits, dim=-1).item()
        category = reverse_label_mapping[category_class]
        analysis["overall"][sentiment] += 1
        analysis["categories"][category][sentiment] += 1
return analysis
def generate_analysis_document(analysis):
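    """Render the analysis dict as a plain-text summary.

    Illustrative shape of the output (counts depend on the input):
        Total Reviews: 3
        Positive: 2
        Negative: 1

        Customer_Service => P: 1 / N: 1
    """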
total = analysis['overall']['positive'] + analysis['overall']['negative']
doc = [f"Total Reviews: {total}", f"Positive: {analysis['overall']['positive']}", f"Negative: {analysis['overall']['negative']}"]
for cat, val in analysis['categories'].items():
total_cat = val['positive'] + val['negative']
if total_cat:
doc.append(f"\n{cat.title()} => P: {val['positive']} / N: {val['negative']}")
return "\n".join(doc)
def create_db_from_analysis(doc):
chunks = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=200).create_documents([doc])
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
return FAISS.from_documents(chunks, embeddings)
class FallbackLLM(LLM):
token: str
temperature: float = 0.7
max_new_tokens: int = 512
model_list: List[str] = Field(default=[
"mistralai/Mistral-7B-Instruct-v0.3",
"HuggingFaceH4/zephyr-7b-beta",
"tiiuae/falcon-7b-instruct"
])
_client: Optional[InferenceClient] = PrivateAttr()
_model_id: Optional[str] = PrivateAttr()
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._client = None
self._model_id = None
        # Try each hosted model in order; the first one that answers a tiny
        # smoke-test generation becomes the backend.
        for model in self.model_list:
            try:
                client = InferenceClient(model=model, token=self.token)
                client.text_generation("Hello", max_new_tokens=5, return_full_text=False)
                self._client = client
                self._model_id = model
                break
            except Exception as e:
                logger.warning(f"Model {model} failed ({e}). Trying next...")
if not self._client:
raise RuntimeError("No fallback LLMs succeeded.")
    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        # Mistral-style [INST] wrapping; the other fallback models tolerate
        # plain instructions even though their native chat templates differ.
        # `stop` is accepted for LangChain compatibility but not forwarded.
        wrapped_prompt = f"<s>[INST] {prompt} [/INST]"
        result = self._client.text_generation(
            wrapped_prompt,
            temperature=self.temperature,
            max_new_tokens=self.max_new_tokens,
            return_full_text=False
        )
        return result.strip()
@property
def _llm_type(self):
return "fallback_llm"
    @property
    def _identifying_params(self):
        # Never expose the API token here; identifying params can end up in
        # logs and repr output.
        return {"model_id": self._model_id or "unknown"}
def initialize_rag_chatbot(db):
    token = get_hf_token()
    if not token:
        return None
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key="answer")
    retriever = db.as_retriever(search_kwargs={"k": 4})
    try:
        llm = FallbackLLM(token=token)
    except RuntimeError as e:
        # All fallback endpoints failed; the caller disables chat.
        logger.error(f"Could not initialize any LLM: {e}")
        return None
    return ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory, return_source_documents=True)
def process_and_initialize(file):
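    """Read one review per line from `file`, run the analysis, and build the RAG chain."""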
if file is None:
return None, None, "Please upload a file."
with open(file, 'r', encoding='utf-8') as f:
reviews = [line.strip() for line in f if line.strip()]
analysis = analyze_reviews(reviews)
doc = generate_analysis_document(analysis)
db = create_db_from_analysis(doc)
    rag_chain = initialize_rag_chatbot(db)
    if rag_chain is None:
        return db, None, "Analysis complete, but no LLM is available (check the HF token); chat is disabled."
    return db, rag_chain, "Processing complete. You can now chat."
def user_query_with_rag(query, qa_chain, chatbot):
history = chatbot or []
if not query.strip():
yield history, ""
return
if qa_chain is None:
history.append((query, "Please upload and process a file first."))
yield history, ""
return
# Add user's question and placeholder for assistant's answer
history.append((query, ""))
yield history, ""
    try:
        # The chain's ConversationBufferMemory tracks the running history, so
        # only the new question needs to be passed in.
        response = qa_chain.invoke({"question": query})
        assistant_response = response.get("answer", "No answer generated.")
# Typing effect
for i in range(len(assistant_response)):
history[-1] = (query, assistant_response[:i + 1])
yield history, ""
time.sleep(0.01)
except Exception as e:
logger.error(f"RAG error: {e}")
history[-1] = (query, "An error occurred while answering your question.")
yield history, ""
def demo():
with gr.Blocks(title="RAG Analyzer") as app:
db_state = gr.State(None)
chain_state = gr.State(None)
gr.Markdown("# 🧠 Customer Review Analyzer with Fallback RAG")
file_input = gr.File(label="Upload review file (.txt)", type="filepath")
status = gr.Textbox(label="Status")
chatbot = gr.Chatbot(label="Chatbot", height=400)
user_input = gr.Textbox(placeholder="Ask about the reviews...", show_label=False)
submit_btn = gr.Button("Send")
process_btn = gr.Button("Process Reviews")
process_btn.click(process_and_initialize, inputs=[file_input], outputs=[db_state, chain_state, status])
submit_btn.click(user_query_with_rag, inputs=[user_input, chain_state, chatbot], outputs=[chatbot, user_input])
user_input.submit(user_query_with_rag, inputs=[user_input, chain_state, chatbot], outputs=[chatbot, user_input])
return app
if __name__ == "__main__":
demo().launch(server_name="0.0.0.0", server_port=7860, share=False)