# pope30 / rag_pipeline.py
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer
import numpy as np


class RAGPipeline:
    def __init__(self):
        print("[RAG] Loading the model and tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
        self.model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large")
        self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
        self.index = None
        self.chunks = []
        self.chunk_embeddings = []
        print("[RAG] Loaded successfully.")

    def build_index(self, chunks, logs=None):
        self.chunks = chunks
        # Normalize embeddings so the dot product in answer() is cosine similarity.
        self.chunk_embeddings = self.embedder.encode(
            chunks, convert_to_numpy=True, normalize_embeddings=True
        )
        if logs is not None:
            logs.append(f"[RAG] Index built with shape {self.chunk_embeddings.shape}")
        # encode() already returns a NumPy array, so no extra copy is needed.
        self.index = self.chunk_embeddings

    def answer(self, question):
        if self.index is None:
            raise ValueError("Index is empty; call build_index() first.")
        question_embedding = self.embedder.encode(
            [question], convert_to_numpy=True, normalize_embeddings=True
        )
        # Retrieve the 5 most similar chunks (cosine similarity on unit vectors).
        similarities = np.dot(self.index, question_embedding.T).squeeze(axis=1)
        top_idx = similarities.argsort()[-5:][::-1]
        context = "\n".join(self.chunks[i] for i in top_idx)
        inputs = self.tokenizer.encode(
            question + " " + context,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        outputs = self.model.generate(inputs, max_length=200)
        answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        sources = [self.chunks[i] for i in top_idx]
        return answer, sources
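

# A minimal usage sketch: builds an index over two chunks, then asks a question.
# The sample chunks and question below are hypothetical placeholders, not part of
# the original project's data. Note that facebook/bart-large is not fine-tuned
# for question answering, so the generated answer may be rough.
if __name__ == "__main__":
    rag = RAGPipeline()
    sample_chunks = [
        "The Eiffel Tower is located in Paris and was completed in 1889.",
        "The Great Wall of China stretches across northern China.",
    ]
    logs = []
    rag.build_index(sample_chunks, logs=logs)
    print("\n".join(logs))
    answer, sources = rag.answer("Where is the Eiffel Tower?")
    print("Answer:", answer)
    print("Sources:", sources)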