import os

import faiss
import gradio as gr
import numpy as np
import torch
from litellm import completion
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer, AutoModel
# GROQ_API_KEY should be provided via the environment (e.g., a Space secret);
# never hardcode API keys in source. litellm reads it from the environment.
PROMPT = """
You are a virtual representative of a retail company and a consultant for customers.
To generate answers, use only information from the context!
Do not ask additional questions; simply offer the products available in the context!
Your goal is to answer customers' questions and help them.
You should advise the customer on choosing products using the context.
If you cannot find a specific answer:
- Reply "I do not know. For more information, please contact: +380954673526" and nothing more.
Always maintain a polite, professional tone. The answer format should be simple, understandable, and clear. Avoid long explanations unless they are necessary.
"""
# Embedding model for dense retrieval and a cross-encoder for reranking.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
reranker_model = CrossEncoder("cross-encoder/ms-marco-TinyBERT-L-6")
def load_documents(file_paths):
    # Read each file whole; one document per file.
    documents = []
    for path in file_paths:
        with open(path, 'r', encoding='utf-8') as file:
            documents.append(file.read().strip())
    return documents
def load_documents_with_chunking(file_paths, chunk_size=500):
    # Split each file into fixed-size character chunks for finer-grained retrieval.
    documents = []
    for path in file_paths:
        with open(path, 'r', encoding='utf-8') as file:
            text = file.read().strip()
        for i in range(0, len(text), chunk_size):
            documents.append(text[i:i + chunk_size])
    return documents
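# Optional variant (a sketch, not wired in above): overlapping chunks reduce
# the chance that an answer is split across a chunk boundary. `overlap` is a
# hypothetical parameter; tune both values for your corpus.
def load_documents_with_overlap(file_paths, chunk_size=500, overlap=100):
    documents = []
    step = max(1, chunk_size - overlap)
    for path in file_paths:
        with open(path, 'r', encoding='utf-8') as file:
            text = file.read().strip()
        for i in range(0, len(text), step):
            documents.append(text[i:i + chunk_size])
    return documents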
class Retriever:
    # Hybrid retriever: lexical BM25 plus dense FAISS (L2) search.
    def __init__(self, documents, tokenizer, model):
        self.documents = documents
        self.bm25 = BM25Okapi([doc.split() for doc in documents])
        self.tokenizer = tokenizer
        self.model = model
        self.index = self.create_faiss_index()

    def create_faiss_index(self):
        embeddings = self.embed_documents(self.documents)
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)
        return index

    def embed_documents(self, docs):
        # Mean-pool token embeddings, masking out padding so pad tokens do not
        # dilute the document vectors. FAISS expects contiguous float32.
        tokens = self.tokenizer(docs, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            hidden = self.model(**tokens).last_hidden_state
        mask = tokens["attention_mask"].unsqueeze(-1).float()
        embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        return embeddings.numpy().astype("float32")

    def search_bm25(self, query, top_k=5):
        scores = self.bm25.get_scores(query.split())
        top_indices = np.argsort(scores)[::-1][:top_k]
        return [self.documents[i] for i in top_indices]

    def search_semantic(self, query, top_k=5):
        query_embedding = self.embed_documents([query])
        _, indices = self.index.search(query_embedding, top_k)
        # FAISS pads the result with -1 when fewer than top_k vectors exist.
        return [self.documents[i] for i in indices[0] if i != -1]
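# Usage sketch ("delivery time" is a hypothetical query; assumes the corpus
# files listed below exist):
#   docs = load_documents(["Company_eng.txt", "base_eng.txt"])
#   retriever = Retriever(docs, tokenizer, model)
#   lexical_hits = retriever.search_bm25("delivery time", top_k=3)
#   semantic_hits = retriever.search_semantic("delivery time", top_k=3)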
class Reranker:
    def __init__(self, reranker):
        self.model = reranker

    def rank(self, query, documents):
        # Score each (query, document) pair with the cross-encoder and return
        # the documents sorted from most to least relevant.
        pairs = [(query, doc) for doc in documents]
        scores = self.model.predict(pairs)
        return [documents[i] for i in np.argsort(scores)[::-1]]
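# Usage sketch: rerank the union of BM25 and FAISS hits so the strongest
# matches come first (CrossEncoder scores are relevance logits; higher means
# more relevant). Variable names continue the sketch above:
#   reranker = Reranker(reranker_model)
#   top_docs = reranker.rank("delivery time", lexical_hits + semantic_hits)[:3]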
class QAChatbot:
    def __init__(self, indexer, reranker):
        self.indexer = indexer
        self.reranker = reranker

    def generate_answer(self, query):
        # Union the lexical and semantic hits (set() deduplicates), rerank,
        # and keep the top three documents as the LLM context.
        bm25_results = self.indexer.search_bm25(query)
        semantic_results = self.indexer.search_semantic(query)
        combined_results = list(set(bm25_results + semantic_results))
        ranked_docs = self.reranker.rank(query, combined_results)
        context = "\n".join(ranked_docs[:3])
        response = completion(
            model="groq/llama3-8b-8192",
            messages=[
                {"role": "system", "content": PROMPT},
                {
                    "role": "user",
                    "content": f"Context: {context}\n\nQuestion: {query}\nAnswer:",
                },
            ],
        )
        return response
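# Smoke-test sketch (hypothetical query; assumes the objects built below and
# a valid GROQ_API_KEY in the environment):
#   bot = QAChatbot(indexer, reranker)
#   reply = bot.generate_answer("What payment methods do you accept?")
#   print(reply.choices[0].message.content)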
# Build the retrieval pipeline once at module load so the Gradio callback
# always has a ready chatbot; rebuilding the index on every request (as the
# previously commented-out code did) would be far too slow.
file_paths = ["Company_eng.txt", "base_eng.txt"]
documents = load_documents(file_paths)
indexer = Retriever(documents, tokenizer, model)
reranker = Reranker(reranker_model)
chatbot = QAChatbot(indexer, reranker)

def chatbot_interface(query, history):
    response = chatbot.generate_answer(query)
    return response.choices[0].message.content

iface = gr.ChatInterface(fn=chatbot_interface, type="messages")

if __name__ == "__main__":
    iface.launch()