import json
import os

import faiss
import numpy as np
import torch
from datasets import load_dataset
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import CrossEncoder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load embedding model
embedding_model = HuggingFaceEmbeddings(
    model_name="all-MiniLM-L12-v2",
    model_kwargs={"device": device}
)

reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# Module-level state shared by the helpers below
ragbench = {}    # cached RAGBench splits, keyed by subset name
index = None     # FAISS index for the currently loaded subset
chunk_docs = []  # text chunks corresponding to the index entries
documents = []   # raw documents collected while building an index

# Ensure data directory exists
os.makedirs("data_local", exist_ok=True)

# Initialize a text splitter
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1024,
    chunk_overlap=100
)

def chunk_documents(docs):
    chunks = [chunk for doc in docs for chunk in text_splitter.split_text(doc)]
    return chunks

def create_faiss_index(dataset):
    # Load dataset
    ragbench_dataset = load_dataset("rungalileo/ragbench", dataset)

    # Reset the shared document buffer so repeated calls do not accumulate docs
    global documents
    documents = []

    for split in ragbench_dataset.keys():
        for row in ragbench_dataset[split]:
            # Ensure document is a string before appending
            doc = row["documents"]
            if isinstance(doc, list):
                # If doc is a list, join its elements into a single string
                doc = " ".join(doc)
            documents.append(doc)  # Extract document text

    # Chunking
    chunked_documents = chunk_documents(documents)

    # Save chunks in JSON (metadata storage) next to the index
    with open(f"data_local/{dataset}_chunked_docs.json", "w") as f:
        json.dump(chunked_documents, f)

    print(f"{dataset}: {len(chunked_documents)} chunks")

    # Convert to embeddings
    embeddings = embedding_model.embed_documents(chunked_documents)

    # Convert embeddings to a NumPy array
    embeddings_np = np.array(embeddings, dtype=np.float32)

    # Build and save the FAISS index (HNSW graph, 32 neighbours per node)
    index = faiss.IndexHNSWFlat(embeddings_np.shape[1], 32)
    index.add(embeddings_np)
    faiss.write_index(index, f"data_local/{dataset}_chunked_index.faiss")

    print(f"{dataset} stored as individual FAISS index!")

def load_ragbench():
    global ragbench
    if ragbench:
        return ragbench
    datasets = ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa',
                'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']
    for dataset in datasets:
        ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
    return ragbench

def load_faiss(query_dataset):
    global index
    faiss_index_path = f"data_local/{query_dataset}_chunked_index.faiss"
    if os.path.exists(faiss_index_path):
        index = faiss.read_index(faiss_index_path)
        print("FAISS index loaded successfully.")
    else:
        print("FAISS index file not found. Run create_faiss_index() first.")

def load_chunks(query_dataset):
    global chunk_docs
    metadata_path = f"data_local/{query_dataset}_chunked_docs.json"
    if os.path.exists(metadata_path):
        with open(metadata_path, "r") as f:
            chunk_docs = json.load(f)
        print("Metadata loaded successfully.")
    else:
        print("Metadata file not found. Run create_faiss_index_file() first.")

def load_data_from_faiss(query_dataset):
    # Populate the module-level index and chunk_docs for the requested subset
    load_faiss(query_dataset)
    load_chunks(query_dataset)

def rerank_documents(query, retrieved_docs):
    # Score each (query, document) pair with the cross-encoder
    scores = reranker.predict([(query, doc) for doc in retrieved_docs])
    # Sort documents by score, highest first
    ranked_docs = [doc for _, doc in sorted(zip(scores, retrieved_docs),
                                            key=lambda pair: pair[0], reverse=True)]
    return ranked_docs[:5]  # Return the 5 most relevant documents
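
# Illustrative end-to-end sketch (an addition, not part of the original file):
# one way the pieces above fit together. The function name retrieve_and_rerank
# and the top_k parameter are assumptions for this example. It assumes
# create_faiss_index(query_dataset) has already been run so the index and
# chunk files exist under data_local/.
def retrieve_and_rerank(query, query_dataset, top_k=10):
    # Load the FAISS index and chunk metadata for this subset
    load_data_from_faiss(query_dataset)

    # Embed the query with the same model used for the chunks
    query_vec = np.array([embedding_model.embed_query(query)], dtype=np.float32)

    # Nearest-neighbour search over the chunk embeddings
    _, indices = index.search(query_vec, top_k)
    retrieved = [chunk_docs[i] for i in indices[0] if i != -1]

    # Re-rank the candidates with the cross-encoder and keep the top 5
    return rerank_documents(query, retrieved)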