"""Product search demo: BM25 keyword prefilter + multilingual-e5 semantic rerank, served with Gradio."""
import pprint

import gradio as gr
import numpy as np
import pandas as pd
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, util

from clean_data import text_normalizer

# --- Data -------------------------------------------------------------------
# One record per product: {'category': ..., 'brand': ..., 'product_name': ...}.
# NOTE: despite the name, this is a list of dicts, not a DataFrame; the name is
# kept for compatibility with anything that imports it.
df = pd.read_csv('./assets/final_combined.csv')[['category', 'brand', 'product_name']].to_dict(orient='records')
# Precomputed document embeddings; row i is assumed to correspond to df[i]
# (TODO confirm the .npy assets were built from this CSV in the same order).
# allow_pickle is only acceptable because these are trusted local assets.
doc_embeddings = np.load('./assets/final_combined_embed.npy', allow_pickle=True)
# Pre-tokenized corpus for BM25 (presumably an object array of token lists,
# hence allow_pickle — verify against the asset builder).
tokenized_corpus = np.load('./assets/final_tokenized_corpus.npy', allow_pickle=True)

# --- Models -----------------------------------------------------------------
# Semantic search model
semantic_search = SentenceTransformer("intfloat/multilingual-e5-base", cache_folder="./assets")
# Full-text search model
keyword_search = BM25Okapi(tokenized_corpus)

# When True, BM25 picks the candidate set that the semantic model reranks
# (the "BM25 + multilingual-e5-base" pipeline named in the UI title).
# Set to False for pure semantic search over the whole corpus.
USE_BM25_PREFILTER = True
# Maximum number of BM25 candidates handed to the semantic reranker.
BM25_TOP_K = 20


def full_text_search(normalized_query, top_k=BM25_TOP_K):
    """Return the best BM25 keyword matches for *normalized_query*.

    Args:
        normalized_query: Query text already passed through ``text_normalizer``;
            it is split on whitespace to form the BM25 query tokens.
        top_k: Maximum number of candidates to return.

    Returns:
        A list of at most ``top_k`` dicts ``{'corpus_id': int, 'score': float}``
        sorted by descending BM25 score, where ``score`` is the BM25 score
        squashed into (0, 1) via ``2/pi * arctan``. Only documents with a
        positive BM25 score are included. Returns ``None`` when the prefilter
        is disabled or no document matches, which tells ``semantic_rerank`` to
        score the whole corpus instead.
    """
    if not USE_BM25_PREFILTER:
        return None
    tokenized_query = normalized_query.split()
    if not tokenized_query:
        return None
    ft_scores = np.asarray(keyword_search.get_scores(tokenized_query))
    # Keep only genuine keyword matches: zero-score documents would otherwise
    # pad the candidate list in arbitrary corpus order.
    matched = np.flatnonzero(ft_scores > 0)
    if matched.size == 0:
        return None
    # Stable descending sort so ties keep corpus order (as the previous
    # sorted(..., reverse=True) did), then take the top_k.
    order = matched[np.argsort(-ft_scores[matched], kind='stable')][:top_k]
    return [
        {'corpus_id': int(doc), 'score': float(2 / np.pi * np.arctan(ft_scores[doc]))}
        for doc in order
    ]


def semantic_rerank(normalized_query, doc_to_fts):
    """Rank documents by dot-product similarity to the query embedding.

    Args:
        normalized_query: Normalized query text; it is encoded with the
            ``"query: "`` prefix used for E5 queries.
        doc_to_fts: Candidates from ``full_text_search``, or ``None`` to score
            every document in the corpus.

    Returns:
        A list of ``{'corpus_id': int, 'score': float}`` dicts sorted by
        descending similarity (empty if ``doc_to_fts`` is an empty list).
    """
    query_embedding = semantic_search.encode("query: " + normalized_query)
    if doc_to_fts is None:
        candidate_ids = np.arange(len(doc_embeddings))
        candidate_embeddings = doc_embeddings
    else:
        candidate_ids = np.array([item['corpus_id'] for item in doc_to_fts], dtype=int)
        # One fancy-index gather yields a 2-D array; a Python list of row arrays
        # is converted to a tensor very slowly and fails when empty.
        candidate_embeddings = doc_embeddings[candidate_ids]
    rr_scores = util.dot_score(query_embedding, candidate_embeddings)[0].tolist()
    doc_to_rr = [
        {'corpus_id': int(doc), 'score': score}
        for doc, score in zip(candidate_ids, rr_scores)
    ]
    doc_to_rr.sort(key=lambda hit: hit['score'], reverse=True)
    return doc_to_rr


def print_results(hits, k_items):
    """Format the first *k_items* hits as pretty-printed product records, one per line."""
    return "".join(
        pprint.pformat(df[hit['corpus_id']], indent=4) + "\n" for hit in hits[:k_items]
    )


def predict(query):
    """Gradio handler: normalize *query*, run BM25 + semantic rerank, return the top 10 as text."""
    normalized_query = text_normalizer(query)
    bm25_hits = full_text_search(normalized_query)
    rr_hits = semantic_rerank(normalized_query, bm25_hits)
    return print_results(rr_hits, k_items=10)


app = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter Search Query..."),
    outputs="text",
    title="BM25 + multilingual-e5-base",
)

if __name__ == "__main__":
    app.launch()