ramy2018 commited on
Commit
7d3c4f0
·
verified ·
1 Parent(s): 110e6e2

Update rag_pipeline.py

Browse files
Files changed (1) hide show
  1. rag_pipeline.py +8 -4
rag_pipeline.py CHANGED
@@ -11,8 +11,9 @@ class RAGPipeline:
11
  pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
12
  self.embedder = SentenceTransformer(modules=[word_embedding_model, pooling_model])
13
 
14
- self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
15
- self.model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
 
16
 
17
  self.chunks = []
18
  self.embeddings = None
@@ -32,11 +33,14 @@ class RAGPipeline:
32
  return [self.chunks[i] for i in top_indices]
33
 
34
  def summarize_text(self, text):
35
- prompt = f"لخص النص التالي:\n{text}"
 
36
  try:
37
  inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
38
  summary_ids = self.model.generate(inputs["input_ids"], max_length=128)
39
- return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
 
 
40
  except Exception as e:
41
  print(f"[RAG][ERROR] أثناء التلخيص: {e}")
42
  return ""
 
11
  pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
12
  self.embedder = SentenceTransformer(modules=[word_embedding_model, pooling_model])
13
 
14
+ # نموذج مخصص للتلخيص العربي
15
+ self.tokenizer = AutoTokenizer.from_pretrained("csebuetnlp/mT5_multilingual_XLSum")
16
+ self.model = AutoModelForSeq2SeqLM.from_pretrained("csebuetnlp/mT5_multilingual_XLSum")
17
 
18
  self.chunks = []
19
  self.embeddings = None
 
33
  return [self.chunks[i] for i in top_indices]
34
 
35
  def summarize_text(self, text):
36
+ print("[RAG][INPUT TO SUMMARIZE]:", text)
37
+ prompt = f"summarize: {text}"
38
  try:
39
  inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
40
  summary_ids = self.model.generate(inputs["input_ids"], max_length=128)
41
+ summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
42
+ print(f"[RAG][DEBUG] الملخص الناتج:\n{summary}")
43
+ return summary
44
  except Exception as e:
45
  print(f"[RAG][ERROR] أثناء التلخيص: {e}")
46
  return ""