import faiss
import numpy as np


def _as_float32_matrix(array):
    """Return *array* as a C-contiguous 2-D float32 NumPy array.

    FAISS only accepts contiguous float32 matrices of shape (n, d).
    Non-ndarray objects exposing ``.numpy()`` (e.g. tensors) are converted
    through that method; anything else (lists, tuples) goes through
    ``np.asarray``. A 1-D input is treated as a single row.

    Raises:
        ValueError: if the input is not 1-D or 2-D.
    """
    if not isinstance(array, np.ndarray):
        array = array.numpy() if hasattr(array, "numpy") else np.asarray(array)
    # ascontiguousarray both casts and guarantees the C layout FAISS requires;
    # a plain astype() on a non-contiguous view would keep it non-contiguous.
    array = np.ascontiguousarray(array, dtype=np.float32)
    if array.ndim == 1:
        array = array.reshape(1, -1)
    if array.ndim != 2:
        raise ValueError(f"expected a 1-D or 2-D array, got shape {array.shape}")
    return array


def save_faiss_embeddings_index(embeddings, file_name):
    """Build an exact L2 FAISS index over *embeddings* and write it to disk.

    Args:
        embeddings: array-like of shape (n, d) (or a single 1-D vector);
            converted to contiguous float32 before indexing.
        file_name: path the serialized index is written to.

    Raises:
        ValueError: if the embeddings are not 1-D/2-D or have dimension 0.
    """
    embeddings = _as_float32_matrix(embeddings)
    if embeddings.shape[1] == 0:
        raise ValueError("embeddings must have a non-zero dimension")

    # Brute-force index using (squared) L2 distance.
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)

    faiss.write_index(index, file_name)


def load_faiss_index(index_path):
    """Read and return a FAISS index previously written to *index_path*."""
    return faiss.read_index(index_path)


def normalize_embeddings(embeddings):
    """Scale each row of *embeddings* to unit L2 norm.

    Rows whose norm is zero are returned unchanged (all zeros) instead of
    becoming NaN through division by zero.

    Args:
        embeddings: 2-D array-like of shape (n, d).

    Returns:
        The row-normalized embeddings.
    """
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Replace zero norms with 1 so all-zero rows stay zero rather than NaN.
    norms = np.where(norms == 0, 1, norms)
    return embeddings / norms


def search_faiss_index(index, query_embedding, k=5):
    """Return the *k* nearest neighbours of *query_embedding* in *index*.

    Args:
        index: a FAISS index (e.g. from ``load_faiss_index``).
        query_embedding: array-like of shape (m, d) or a single 1-D vector;
            converted to contiguous float32 as FAISS requires.
        k: number of neighbours to return per query.

    Returns:
        Tuple ``(D, I)``: distances and indices, each of shape (m, k).
    """
    query_embedding = _as_float32_matrix(query_embedding)
    D, I = index.search(query_embedding, k)  # D: distances, I: indices
    return D, I


def Z_load_embeddings_and_index(file_name):
    """Load saved embeddings and a FAISS index sharing the base *file_name*.

    Embeddings are read from ``f"{file_name}_embeddings.npy"``; the index is
    read from *file_name* itself.

    NOTE(review): ``save_faiss_embeddings_index`` does not write the
    ``_embeddings.npy`` file, so that file must come from elsewhere. Confirm
    that reading the index from the bare *file_name* (no extension) is intended.

    Returns:
        Tuple ``(embeddings, index)``.
    """
    embeddings = np.load(f"{file_name}_embeddings.npy")
    index = faiss.read_index(file_name)
    return embeddings, index