import gradio as gr
import torch
from helper_functions import *
from rank_bm25 import BM25L
import time
import pprint
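
# NOTE: the star import above is assumed to provide the helpers used below:
# normalizer, categorizer, semantic_model, category_map, make_request,
# HTTPException, calculate_interrelations, check_validity, semantic_search,
# hybrid_search, is_cheapest, and rerank_results.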

def print_results(results):
    result_string = ''
    for hit in results:
        result_string += pprint.pformat(hit, indent=4) + "\n"
    return result_string.strip()
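
# A minimal sketch of the hybrid scoring idea used in predict() below.
# The real hybrid_search comes from helper_functions and is not shown here;
# this hypothetical _hybrid_scores_sketch assumes a convex combination of
# min-max-normalized keyword (BM25) and semantic (cosine) scores.
def _hybrid_scores_sketch(keyword_scores, semantic_scores, alpha=0.5):
    def _normalize(scores):
        lo, hi = min(scores), max(scores)
        return [(s - lo) / (hi - lo) if hi > lo else 0.0 for s in scores]
    kw, sem = _normalize(keyword_scores), _normalize(semantic_scores)
    # Weighted sum: alpha controls the keyword/semantic trade-off
    return [alpha * k + (1 - alpha) * s for k, s in zip(kw, sem)]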

def predict(query):
    start_time = time.time()
    normalized_query_list = [normalizer.clean_text(query)]
    normalize_query_time = time.time() - start_time

    # Base URL for the search API
    base_url = "https://api.omaline.dev/search/product/search"

    # Construct the query string for the API request
    query_string = "&".join([f"k={item}" for item in normalized_query_list])
    # NOTE: the quotes in sortBy='' are sent literally as part of the URL
    url = f"{base_url}?limit=50&sortBy=''&{query_string}"

    # Make the request to the API and handle exceptions
    request_start_time = time.time()
    try:
        request_json = make_request(url)
    except HTTPException as e:
        return {"error": str(e)}
    except Exception as e:
        return {"error": f"An error occurred while making the request: {e}"}
    request_end_time = time.time()
    request_time = request_end_time - request_start_time

    # Build a normalized text representation of each product
    normalization_start_time = time.time()
    tasks = []
    for product in request_json:
        try:
            tasks.append(normalizer.clean_text(
                product["name"]
                + " "
                + product["brandName"]
                + " "
                + product["providerName"]
                + " "
                + product["categoryName"]
            ))
        except (KeyError, TypeError):
            return {"error": "Normalization failed: make sure each product is a dictionary "
                             "with the fields ['name', 'brandName', 'providerName', 'categoryName'] present."}
    normalization_end_time = time.time()
    normalization_time = normalization_end_time - normalization_start_time

    try:
        # Categorize products
        categorize_start_time = time.time()
        predicted_categories = categorizer.predict(tasks)
        for idx, product in enumerate(request_json):
            product["Inferred Category"] = category_map[predicted_categories[0][idx][0]][0]
        categorize_end_time = time.time()
        categorize_time = categorize_end_time - categorize_start_time
    except Exception as e:
        return {"error": f"An error occurred while categorizing products: {e}"}

    representation_list = tasks
    try:
        # Tokenize representations for keyword search
        # (set() de-duplicates tokens, so every term frequency within a document is 1)
        tokenization_start_time = time.time()
        corpus = [set(representation.split(" ")) for representation in representation_list]
        keyword_search = BM25L(corpus)
        tokenization_end_time = time.time()
        tokenization_time = tokenization_end_time - tokenization_start_time
    except Exception as e:
        return {"error": f"An error occurred while tokenizing representations: {e}"}

    # Encode representations for semantic search
    encode_start_time = time.time()
    doc_embeddings = semantic_model.encode(
        representation_list, convert_to_tensor=True
    )
    encode_end_time = time.time()
    encode_time = encode_end_time - encode_start_time

    try:
        # Calculate interrelations between products
        calculate_interrelations_start_time = time.time()
        calculate_interrelations(request_json, doc_embeddings)
        calculate_interrelations_end_time = time.time()
        calculate_interrelations_time = calculate_interrelations_end_time - calculate_interrelations_start_time

        # Perform a hybrid search for each query;
        # this re-ranks the search results for each normalized query
        process_time = time.time()
        for query in normalized_query_list:
            keyword_scores = check_validity(query, keyword_search)
            semantic_scores = semantic_search(query, doc_embeddings)
            hybrid_scores = hybrid_search(keyword_scores, semantic_scores)
            is_cheapest(query, request_json)
            results = rerank_results(request_json, hybrid_scores)
        process_end_time = time.time()
        process_time_taken = process_end_time - process_time

        time_taken = time.time() - start_time
        # hits = {"results": results, "time_taken": time_taken, "normalize_query_time": normalize_query_time,
        #         "request_time": request_time, "normalization_time": normalization_time,
        #         "categorize_time": categorize_time, "tokenization_time": tokenization_time, "encode_time": encode_time,
        #         "calculate_interrelations_time": calculate_interrelations_time,
        #         "process_time": process_time_taken}
        return print_results(results)
    except Exception as e:
        error_message = f"An error occurred during processing: {e}"
        return {"error": error_message}

app = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter Search Query..."),
    outputs="text",
    title="model name: gte-small, model size: 133MB, Pipeline Without Translation",
)

app.launch()
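
# Usage: running this script starts a local Gradio server; open the printed
# local URL in a browser and enter a search query in the textbox.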