|
import faiss |
|
import numpy as np |
|
|
|
def save_faiss_embeddings_index(embeddings, file_name):
    """Build an exact (brute-force) L2 FAISS index over ``embeddings`` and save it.

    Args:
        embeddings: Vectors to index, expected shape ``(n_vectors, dim)``. A NumPy
            array is used directly; any other object is converted with
            ``.numpy()`` (presumably a torch tensor — it is detached and moved to
            CPU first when it offers ``detach``/``cpu``). A single 1-D vector is
            treated as one row.
        file_name: Destination path handed to ``faiss.write_index``.

    Raises:
        ValueError: If the embeddings are not (or cannot be viewed as) 2-D.
    """
    if not isinstance(embeddings, np.ndarray):
        # Tensors that require grad or live on an accelerator cannot call
        # .numpy() directly; detach/cpu are applied only when available.
        if hasattr(embeddings, "detach"):
            embeddings = embeddings.detach()
        if hasattr(embeddings, "cpu"):
            embeddings = embeddings.cpu()
        embeddings = embeddings.numpy()

    # FAISS requires a C-contiguous float32 matrix; a plain astype() would keep
    # non-contiguous layouts (e.g. column slices), which FAISS rejects.
    embeddings = np.ascontiguousarray(np.atleast_2d(embeddings), dtype=np.float32)
    if embeddings.ndim != 2:
        raise ValueError(
            f"embeddings must be 2-D (n_vectors, dim), got shape {embeddings.shape}"
        )

    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)

    faiss.write_index(index, file_name)
|
|
|
|
|
def load_faiss_index(index_path):
    """Deserialize and return the FAISS index stored at ``index_path``."""
    return faiss.read_index(index_path)
|
|
|
def normalize_embeddings(embeddings):
    """Scale each embedding to unit L2 norm.

    Normalization runs along the last axis. A 2-D ``(n, dim)`` matrix is
    therefore normalized row-wise, exactly as before, and a single 1-D vector
    is accepted too. All-zero vectors are returned as zeros instead of NaN.

    Args:
        embeddings: Array-like of vectors, vectors along the last axis.

    Returns:
        A new array with the same shape as ``embeddings``.
    """
    norms = np.linalg.norm(embeddings, axis=-1, keepdims=True)
    # A zero norm would give 0/0 -> NaN; dividing by 1 leaves the zero row as-is.
    norms = np.where(norms == 0, 1, norms)
    return embeddings / norms
|
|
|
def search_faiss_index(index, query_embedding, k=5):
    """Query ``index`` for the ``k`` nearest neighbours of ``query_embedding``.

    Args:
        index: A FAISS index, or anything exposing ``search(x, k)``.
        query_embedding: Query vector(s), either ``(n_queries, dim)`` or a single
            ``(dim,)`` vector.
            - NumPy arrays, lists and tuples are coerced to the C-contiguous
              float32 2-D matrix that FAISS requires.
            - Other objects (e.g. tensors when FAISS torch interop is in use)
              are passed through unchanged.
        k: Number of neighbours to return per query.

    Returns:
        ``(D, I)`` exactly as returned by ``index.search``. Per FAISS
        convention these are the distances and ids, each ``(n_queries, k)``.
    """
    if isinstance(query_embedding, (np.ndarray, list, tuple)):
        # FAISS rejects 1-D, float64 or non-contiguous inputs; normalize them here.
        query_embedding = np.ascontiguousarray(
            np.atleast_2d(query_embedding), dtype=np.float32
        )

    D, I = index.search(query_embedding, k)

    return D, I
|
|
|
|
|
def Z_load_embeddings_and_index(file_name):
    """Load a saved embeddings matrix together with its FAISS index.

    The embeddings are read from the sidecar file ``"<file_name>_embeddings.npy"``
    and the index from ``file_name`` itself.

    NOTE(review): ``save_faiss_embeddings_index`` in this module writes only the
    index, not the ``_embeddings.npy`` sidecar; confirm where that file is produced.

    Returns:
        An ``(embeddings, index)`` tuple.
    """
    sidecar_path = f"{file_name}_embeddings.npy"
    loaded_embeddings = np.load(sidecar_path)
    loaded_index = faiss.read_index(file_name)
    return loaded_embeddings, loaded_index
|
|