import json

import faiss
from sentence_transformers import CrossEncoder
from sentence_transformers import SentenceTransformer


class CustomRetriever:
    """Two-stage retriever: FAISS vector search followed by cross-encoder reranking.

    A bi-encoder embeds the query and the FAISS index returns ``top_k`` candidate
    chunks. A cross-encoder then reranks those candidates. The best ones are
    formatted into a text context together with their source metadata.
    """

    def __init__(
        self,
        chunks_path,
        embeddings_path,
        metadata_path,
        top_k=50,
        bi_encoder_name="deepvk/USER-bge-m3",
        cross_encoder_name="DiTy/cross-encoder-russian-msmarco",
    ):
        """Load the models, the chunk texts, the FAISS index and the metadata.

        Args:
            chunks_path: Path to a JSON file holding the chunk texts. It is indexed
                by FAISS row id, so it is expected to be a JSON list aligned with
                the index rows.
            embeddings_path: Path to a serialized FAISS index (``faiss.read_index``).
            metadata_path: Path to a JSON object mapping the stringified FAISS row
                id to a dict with ``book``, ``article_num`` and ``link`` keys.
            top_k: Number of candidates fetched from FAISS before reranking.
            bi_encoder_name: SentenceTransformer model used to embed queries.
            cross_encoder_name: CrossEncoder model used to rerank candidates.
        """
        self.model_bi = SentenceTransformer(bi_encoder_name)
        self.model_cross = CrossEncoder(cross_encoder_name)
        # Explicit UTF-8: the data contains Cyrillic text, and the platform default
        # encoding (e.g. cp1251 on Windows) would fail or garble it.
        with open(chunks_path, "r", encoding="utf-8") as f:
            self.chunks = json.load(f)
        self.index = faiss.read_index(embeddings_path)
        self.top_k = top_k
        with open(metadata_path, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)

    def retrieve(self, query, top_n=5):
        """Return a formatted context built from the best-matching chunks.

        Args:
            query: The user query string.
            top_n: Maximum number of reranked facts to include (default 5).

        Returns:
            A string with up to ``top_n`` "Факт N: ..." entries, each followed by
            its book, article number and link. Returns an empty string when the
            index yields no candidates.
        """
        query_vector = self.model_bi.encode([query])
        # normalize_L2 works in place; this presumably makes inner-product search
        # equivalent to cosine similarity (TODO confirm the index metric).
        faiss.normalize_L2(query_vector)
        _, indices = self.index.search(query_vector, self.top_k)

        # FAISS pads results with -1 when the index holds fewer than top_k vectors.
        # Without this filter, self.chunks[-1] would silently return the last chunk.
        chunk_ids = [int(idx) for idx in indices[0] if idx >= 0]
        if not chunk_ids:
            return ""
        possible_answers = [self.chunks[idx] for idx in chunk_ids]

        # rank() returns dicts sorted best-first; "corpus_id" is the position of
        # the document within possible_answers.
        ranked = self.model_cross.rank(query, possible_answers)

        facts = []
        # Slicing (instead of range(5)) avoids IndexError when there are fewer
        # candidates than top_n.
        for position, hit in enumerate(ranked[:top_n], start=1):
            candidate = hit["corpus_id"]
            meta = self.metadata[str(chunk_ids[candidate])]
            facts.append(
                f"Факт {position}: {possible_answers[candidate]}. Источник:\n"
                f"книга - {meta['book']}\n"
                f"номер статьи - {meta['article_num']}\n"
                f"ссылка на книгу - {meta['link']}\n"
            )
        return "".join(facts)