# Hugging Face Space: retail QA chatbot (RAG demo)
import os

import faiss
import gradio as gr
import numpy as np
import torch
from litellm import completion
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer, AutoModel
# --- 1. Document loading ---
def load_documents(file_paths):
    """Read every file in *file_paths* as UTF-8 and return the stripped texts, in order."""
    def _read(path):
        with open(path, 'r', encoding='utf-8') as fh:
            return fh.read().strip()
    return [_read(path) for path in file_paths]
# --- 2. Document indexing ---
class DocumentIndexer:
    """Indexes a corpus for hybrid retrieval: lexical (BM25) and semantic (FAISS)."""

    def __init__(self, documents):
        self.documents = documents
        # BM25 over raw whitespace tokens — no lowercasing/stemming; fine for a demo.
        self.bm25 = BM25Okapi([doc.split() for doc in documents])
        self.tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
        self.model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
        self.index = self.create_faiss_index()

    def create_faiss_index(self):
        """Build a flat (exact) L2 FAISS index over embeddings of the whole corpus."""
        embeddings = self.embed_documents(self.documents)
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)
        return index

    def embed_documents(self, docs):
        """Embed *docs* with masked mean pooling; returns a (len(docs), dim) float array.

        Fix: the original pooled with a plain ``.mean(dim=1)``, which averages
        padding-token vectors into the embedding and dilutes short texts in a
        padded batch. The mean is now weighted by the attention mask, the
        standard pooling for sentence-transformers checkpoints.
        """
        tokens = self.tokenizer(docs, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            hidden = self.model(**tokens).last_hidden_state
        mask = tokens["attention_mask"].unsqueeze(-1).type_as(hidden)
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid 0-division on all-pad rows
        return (summed / counts).numpy()

    def search_bm25(self, query, top_k=5):
        """Return up to *top_k* documents ranked by BM25 score for *query*."""
        query_terms = query.split()
        scores = self.bm25.get_scores(query_terms)
        top_indices = np.argsort(scores)[::-1][:top_k]
        return [self.documents[i] for i in top_indices]

    def search_semantic(self, query, top_k=5):
        """Return up to *top_k* nearest documents to *query* in embedding space."""
        query_embedding = self.embed_documents([query])
        # Fix: clamp top_k to the corpus size and drop FAISS's -1 padding
        # indices — the original mapped -1 to documents[-1], silently
        # duplicating the last document when the corpus was small.
        top_k = min(top_k, len(self.documents))
        distances, indices = self.index.search(query_embedding, top_k)
        return [self.documents[i] for i in indices[0] if i >= 0]
# --- 3. Reranker ---
class Reranker:
    """Cross-encoder reranker: reorders candidate documents by relevance to a query."""

    def __init__(self, model_name="cross-encoder/ms-marco-TinyBERT-L-6"):
        # Fix: the original also loaded an AutoTokenizer here that was never
        # used — CrossEncoder tokenizes internally — wasting a model download
        # and memory on startup.
        self.model = CrossEncoder(model_name)

    def rank(self, query, documents):
        """Return *documents* sorted by descending cross-encoder relevance score."""
        pairs = [(query, doc) for doc in documents]
        scores = self.model.predict(pairs)
        return [documents[i] for i in np.argsort(scores)[::-1]]
# --- 4. Answer generation ---
class QAChatbot:
    """RAG chatbot: hybrid retrieval -> cross-encoder rerank -> LLM completion."""

    # Fix: the original passed an undefined global ``PROMPT`` as the system
    # message, which raised NameError on every call. A concrete default
    # system prompt is defined on the class instead.
    _SYSTEM_PROMPT = (
        "You are a helpful assistant for a retail company. "
        "Answer the user's question using only the provided context. "
        "If the context does not contain the answer, say that you don't know."
    )

    def __init__(self, indexer, reranker):
        self.indexer = indexer
        self.reranker = reranker

    def generate_answer(self, query):
        """Run the full RAG pipeline for *query* and return the raw completion response."""
        # 1. Retrieve candidates from both the lexical and the semantic index.
        bm25_results = self.indexer.search_bm25(query)
        semantic_results = self.indexer.search_semantic(query)
        # set() deduplicates overlap between the two retrievers; the lost
        # ordering is irrelevant because the reranker rescores everything.
        combined_results = list(set(bm25_results + semantic_results))
        # 2. Rerank the merged candidates.
        ranked_docs = self.reranker.rank(query, combined_results)
        # 3. Generate the answer from the top-3 documents.
        context = "\n".join(ranked_docs[:3])
        response = completion(
            model="groq/llama3-8b-8192",
            messages=[
                {
                    "role": "system",
                    "content": self._SYSTEM_PROMPT,
                },
                {
                    "role": "user",
                    "content": f"Context: {context}\n\nQuestion: {query}\nAnswer:",
                },
            ],
        )
        return response
# --- 5. Gradio handler ---
def chatbot_interface(query):
    """Gradio handler: answer *query* from the indexed document base.

    Fix: the original re-read the corpus, re-downloaded three models and
    rebuilt both the BM25 and FAISS indexes on EVERY call. The pipeline is
    now built lazily once and cached on the function object.
    """
    if not query or not query.strip():
        return ""  # nothing to answer; also avoids a pointless LLM call
    if not hasattr(chatbot_interface, "_chatbot"):
        file_paths = ["company.txt", "Base.txt"]  # point these at your files
        documents = load_documents(file_paths)
        indexer = DocumentIndexer(documents)
        reranker = Reranker()
        chatbot_interface._chatbot = QAChatbot(indexer, reranker)
    answer = chatbot_interface._chatbot.generate_answer(query)
    return answer["choices"][0]["message"]["content"]
# Build the Gradio UI: one text box in, the generated answer out.
# Fix: ``live=True`` was removed — it re-ran the handler on every keystroke,
# costing one LLM API call (and, originally, a full re-index) per character.
# Answers are now generated only on submit.
iface = gr.Interface(
    fn=chatbot_interface,
    inputs="text",
    outputs="text",
    title="Чат-бот для ритейл-компанії",
    description="Запитуйте мене про товари і я допоможу вам вибрати найкраще!",
)

# Launch only when executed as a script (not on import).
if __name__ == "__main__":
    iface.launch()