import os
import time
import logging
from typing import Optional, List, Mapping, Any

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import InferenceClient
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain_core.language_models.llms import LLM
from pydantic import Field, PrivateAttr

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

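# Load the three classifier heads: binary sentiment, sarcasm detection (used to
# catch sarcastic "positive" reviews), and topic classification.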
sentiment_tokenizer = AutoTokenizer.from_pretrained("dnzblgn/Sentiment-Analysis-Customer-Reviews")
sentiment_model = AutoModelForSequenceClassification.from_pretrained("dnzblgn/Sentiment-Analysis-Customer-Reviews")
sarcasm_tokenizer = AutoTokenizer.from_pretrained("dnzblgn/Sarcasm-Detection-Customer-Reviews")
sarcasm_model = AutoModelForSequenceClassification.from_pretrained("dnzblgn/Sarcasm-Detection-Customer-Reviews")
doc_tokenizer = AutoTokenizer.from_pretrained("dnzblgn/Customer-Reviews-Classification")
doc_model = AutoModelForSequenceClassification.from_pretrained("dnzblgn/Customer-Reviews-Classification")

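# Category ids emitted by the topic classifier, with a reverse map for decoding.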
label_mapping = {
    "shipping_and_delivery": 0,
    "customer_service": 1,
    "price_and_value": 2,
    "quality_and_performance": 3,
    "use_and_design": 4,
    "other": 5,
}
reverse_label_mapping = {v: k for k, v in label_mapping.items()}

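# The Hub token can be exposed under several env var names depending on the host,
# so check the common ones in order.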
def get_hf_token():
    return (
        os.environ.get("HF")
        or os.environ.get("HF_TOKEN")
        or os.environ.get("HUGGINGFACE_TOKEN")
        or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    )

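# Run every review through the classifiers and tally sentiment overall and per
# category. Apparent positives are re-checked for sarcasm and flipped if flagged.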
def analyze_reviews(reviews):
    analysis = {
        "overall": {"positive": 0, "negative": 0},
        "categories": {label: {"positive": 0, "negative": 0} for label in label_mapping},
    }
    with torch.no_grad():  # inference only, so skip gradient tracking
        for review in reviews:
            sentiment_inputs = sentiment_tokenizer(review, return_tensors="pt", truncation=True, padding=True, max_length=512)
            sentiment_class = torch.argmax(sentiment_model(**sentiment_inputs).logits, dim=-1).item()
            sentiment = "positive" if sentiment_class == 0 else "negative"  # class 0 = positive for this model
            if sentiment == "positive":
                sarcasm_inputs = sarcasm_tokenizer(review, return_tensors="pt", truncation=True, padding=True, max_length=512)
                if torch.argmax(sarcasm_model(**sarcasm_inputs).logits, dim=-1).item() == 1:
                    sentiment = "negative"
            doc_inputs = doc_tokenizer(review, return_tensors="pt", truncation=True, padding=True, max_length=512)
            category_class = torch.argmax(doc_model(**doc_inputs).logits, dim=-1).item()
            category = reverse_label_mapping[category_class]
            analysis["overall"][sentiment] += 1
            analysis["categories"][category][sentiment] += 1
    return analysis

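# Render the tallies as a small plain-text summary for the retriever to index.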
def generate_analysis_document(analysis):
    total = analysis["overall"]["positive"] + analysis["overall"]["negative"]
    doc = [
        f"Total Reviews: {total}",
        f"Positive: {analysis['overall']['positive']}",
        f"Negative: {analysis['overall']['negative']}",
    ]
    for cat, val in analysis["categories"].items():
        total_cat = val["positive"] + val["negative"]
        if total_cat:  # skip categories with no reviews
            doc.append(f"\n{cat.replace('_', ' ').title()} => P: {val['positive']} / N: {val['negative']}")
    return "\n".join(doc)

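# Chunk the summary with overlap and embed it into an in-memory FAISS index.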
def create_db_from_analysis(doc):
    chunks = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=200).create_documents([doc])
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return FAISS.from_documents(chunks, embeddings)

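# LangChain LLM wrapper that probes a list of hosted models at construction time
# and keeps the first one that responds, so the app degrades gracefully when a
# model is unavailable on the Inference API.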
class FallbackLLM(LLM):
    token: str
    temperature: float = 0.7
    max_new_tokens: int = 512
    model_list: List[str] = Field(default_factory=lambda: [
        "mistralai/Mistral-7B-Instruct-v0.3",
        "HuggingFaceH4/zephyr-7b-beta",
        "tiiuae/falcon-7b-instruct",
    ])

    _client: Optional[InferenceClient] = PrivateAttr(default=None)
    _model_id: Optional[str] = PrivateAttr(default=None)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._client = None
        self._model_id = None

        # Probe each candidate with a tiny generation call; keep the first that answers.
        for model in self.model_list:
            try:
                client = InferenceClient(model=model, token=self.token)
                client.text_generation("Hello", max_new_tokens=5, return_full_text=False)
                self._client = client
                self._model_id = model
                break
            except Exception as e:
                logger.warning(f"Model {model} failed ({e}). Trying next...")

        if not self._client:
            raise RuntimeError("No fallback LLMs succeeded.")

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        # [INST] wrapping matches Mistral's instruction format; the other fallbacks tolerate it.
        wrapped_prompt = f"<s>[INST] {prompt} [/INST]"
        result = self._client.text_generation(
            wrapped_prompt,
            temperature=self.temperature,
            max_new_tokens=self.max_new_tokens,
            return_full_text=False,
            stop_sequences=stop,
        )
        return result.strip()

    @property
    def _llm_type(self) -> str:
        return "fallback_llm"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        # Report the chosen model only; never expose the token in identifying params.
        return {"model_id": self._model_id or "unknown"}

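# Wire retriever, memory, and fallback LLM into a conversational RAG chain.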
def initialize_rag_chatbot(db):
    token = get_hf_token()
    if not token:
        return None  # no token available; the UI reports this to the user
    # output_key tells the memory which output to store, since
    # return_source_documents=True makes the chain return multiple keys.
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key="answer")
    retriever = db.as_retriever(search_kwargs={"k": 4})
    llm = FallbackLLM(token=token)
    return ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory, return_source_documents=True)

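# Read one review per line, run the analysis, and build the index and chain.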
def process_and_initialize(file):
    if file is None:
        return None, None, "Please upload a file."
    with open(file, "r", encoding="utf-8") as f:
        reviews = [line.strip() for line in f if line.strip()]
    analysis = analyze_reviews(reviews)
    doc = generate_analysis_document(analysis)
    db = create_db_from_analysis(doc)
    rag_chain = initialize_rag_chatbot(db)
    if rag_chain is None:
        return db, None, "Analysis complete, but no Hugging Face token was found; chat is disabled."
    return db, rag_chain, "Processing complete. You can now chat."

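# Generator handler: yields intermediate chat states so Gradio streams the answer
# character by character instead of blocking until the chain finishes.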
def user_query_with_rag(query, qa_chain, chatbot):
    history = chatbot or []
    if not query.strip():
        yield history, ""
        return

    if qa_chain is None:
        history.append((query, "Please upload and process a file first."))
        yield history, ""
        return

    # Show the question immediately with an empty answer slot.
    history.append((query, ""))
    yield history, ""

    try:
        # The chain's memory supplies the chat history, so only the question is passed.
        response = qa_chain.invoke({"question": query})
        assistant_response = response.get("answer", "No answer generated.")

        # Reveal the answer one character at a time for a typing effect.
        for i in range(len(assistant_response)):
            history[-1] = (query, assistant_response[:i + 1])
            yield history, ""
            time.sleep(0.01)
    except Exception as e:
        logger.error(f"RAG error: {e}")
        history[-1] = (query, "An error occurred while answering your question.")
        yield history, ""

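# Assemble the Gradio UI: file upload and processing on top, chat below, with the
# vector store and chain kept in per-session State.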
def demo():
    with gr.Blocks(title="RAG Analyzer") as app:
        db_state = gr.State(None)
        chain_state = gr.State(None)
        gr.Markdown("# 🧠 Customer Review Analyzer with Fallback RAG")
        file_input = gr.File(label="Upload review file (.txt)", type="filepath")
        status = gr.Textbox(label="Status")
        chatbot = gr.Chatbot(label="Chatbot", height=400)
        user_input = gr.Textbox(placeholder="Ask about the reviews...", show_label=False)
        submit_btn = gr.Button("Send")
        process_btn = gr.Button("Process Reviews")

        process_btn.click(process_and_initialize, inputs=[file_input], outputs=[db_state, chain_state, status])
        # Both the Send button and pressing Enter in the textbox submit a query.
        submit_btn.click(user_query_with_rag, inputs=[user_input, chain_state, chatbot], outputs=[chatbot, user_input])
        user_input.submit(user_query_with_rag, inputs=[user_input, chain_state, chatbot], outputs=[chatbot, user_input])

    return app


if __name__ == "__main__":
    demo().launch(server_name="0.0.0.0", server_port=7860, share=False)