Rag_proj / app.py
sgt444pepper's picture
init!
6258aee verified
raw
history blame
4.54 kB
import gradio as gr
import faiss
import numpy as np
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModel
from litellm import completion
import os
import torch
from sentence_transformers import CrossEncoder
# --- 1. Завантаження документів ---
def load_documents(file_paths):
documents = []
for path in file_paths:
with open(path, 'r', encoding='utf-8') as file:
documents.append(file.read().strip())
return documents
# --- 2. Індексування документів ---
class DocumentIndexer:
def __init__(self, documents):
self.documents = documents
self.bm25 = BM25Okapi([doc.split() for doc in documents])
self.tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
self.model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
self.index = self.create_faiss_index()
def create_faiss_index(self):
embeddings = self.embed_documents(self.documents)
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)
return index
def embed_documents(self, docs):
tokens = self.tokenizer(docs, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
embeddings = self.model(**tokens).last_hidden_state.mean(dim=1).numpy()
return embeddings
def search_bm25(self, query, top_k=5):
query_terms = query.split()
scores = self.bm25.get_scores(query_terms)
top_indices = np.argsort(scores)[::-1][:top_k]
return [self.documents[i] for i in top_indices]
def search_semantic(self, query, top_k=5):
query_embedding = self.embed_documents([query])
distances, indices = self.index.search(query_embedding, top_k)
return [self.documents[i] for i in indices[0]]
# --- 3. Ререйкер ---
class Reranker:
def __init__(self, model_name="cross-encoder/ms-marco-TinyBERT-L-6"):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = CrossEncoder(model_name)
def rank(self, query, documents):
pairs = [(query, doc) for doc in documents]
scores = self.model.predict(pairs)
ranked_docs = [documents[i] for i in np.argsort(scores)[::-1]]
return ranked_docs
# --- 4. Генерація відповіді ---
class QAChatbot:
def __init__(self, indexer, reranker):
self.indexer = indexer
self.reranker = reranker
def generate_answer(self, query):
# 1. Шукаємо релевантні документи
bm25_results = self.indexer.search_bm25(query)
semantic_results = self.indexer.search_semantic(query)
combined_results = list(set(bm25_results + semantic_results))
# 2. Ранжуємо документи
ranked_docs = self.reranker.rank(query, combined_results)
# 3. Генеруємо відповідь
context = "\n".join(ranked_docs[:3]) # Використовуємо топ-3 документи
response = completion(
model="groq/llama3-8b-8192",
messages=[
{
"role": "system",
"content": PROMPT
},
{
"role": "user",
"content": f"Context: {context}\n\nQuestion: {query}\nAnswer:",
}
],
)
return response
# --- 5. Створення Gradio інтерфейсу ---
def chatbot_interface(query):
file_paths = ["company.txt", "Base.txt"] # Вкажіть ваші файли
documents = load_documents(file_paths)
# Налаштовуємо індексер та ререйкер
indexer = DocumentIndexer(documents)
reranker = Reranker()
# Запускаємо чат-бота
chatbot = QAChatbot(indexer, reranker)
answer = chatbot.generate_answer(query)
return answer["choices"][0]["message"]["content"]
# Створення інтерфейсу Gradio
iface = gr.Interface(fn=chatbot_interface, inputs="text", outputs="text",
live=True, title="Чат-бот для ритейл-компанії",
description="Запитуйте мене про товари і я допоможу вам вибрати найкраще!")
# Запуск інтерфейсу
if __name__ == "__main__":
iface.launch()