Spaces:
Runtime error
Runtime error
Daniel Doña
committed on
Commit
·
98f5b72
1
Parent(s):
1796e7a
Filter by rank
Browse files
- app.py +2 -2
- src/SemanticSearch.py +3 -1
app.py
CHANGED
|
@@ -129,7 +129,7 @@ def semantic_search(message):
|
|
| 129 |
|
| 130 |
gradio.Info("Performing semantic search...", duration=5)
|
| 131 |
|
| 132 |
-
results = extractor.extract(message, n_results=
|
| 133 |
|
| 134 |
print(results)
|
| 135 |
|
|
@@ -196,7 +196,7 @@ with gradio.Blocks() as demo:
|
|
| 196 |
|
| 197 |
# Options
|
| 198 |
model_selector = gradio.Dropdown([(item["name"], item["repo"]) for item in model_options], value="daniel-dona/sparql-model-era-lora-128-qwen3-4b", render=False, interactive=True, label="Model", info="Base model provided as reference, SFT model is trained on generated datasets, GRPO model is reinforced on to of SFT.")
|
| 199 |
-
model_temperature = gradio.Slider(0, 1, render=False, step=0.1, value=0
|
| 200 |
sparql_endpoint = gradio.Textbox(value=SPARQL_ENDPOINT, render=False, interactive=True, label="SPARQL endpoint", info="SPARQL endpoint to send the generate queries to fetch results.")
|
| 201 |
model_semantic = gradio.Checkbox(value=True, render=False, interactive=True, label="Enable semantic entity lookup", info="Use embeddings and reranking model to retrieve relevant objects.")
|
| 202 |
model_thinking = gradio.Checkbox(value=False, render=False, interactive=True, label="Enable thinking", info="Use thinking mode in the Jinja chat template, mostly for GRPO experiments.")
|
|
|
|
| 129 |
|
| 130 |
gradio.Info("Performing semantic search...", duration=5)
|
| 131 |
|
| 132 |
+
results = extractor.extract(message, n_results=5, rerank=True)
|
| 133 |
|
| 134 |
print(results)
|
| 135 |
|
|
|
|
| 196 |
|
| 197 |
# Options
|
| 198 |
model_selector = gradio.Dropdown([(item["name"], item["repo"]) for item in model_options], value="daniel-dona/sparql-model-era-lora-128-qwen3-4b", render=False, interactive=True, label="Model", info="Base model provided as reference, SFT model is trained on generated datasets, GRPO model is reinforced on to of SFT.")
|
| 199 |
+
model_temperature = gradio.Slider(0, 1, render=False, step=0.1, value=0, label="Temperature", info="Ajust model variability, with a value of 0, the model use greedy decoding.")
|
| 200 |
sparql_endpoint = gradio.Textbox(value=SPARQL_ENDPOINT, render=False, interactive=True, label="SPARQL endpoint", info="SPARQL endpoint to send the generate queries to fetch results.")
|
| 201 |
model_semantic = gradio.Checkbox(value=True, render=False, interactive=True, label="Enable semantic entity lookup", info="Use embeddings and reranking model to retrieve relevant objects.")
|
| 202 |
model_thinking = gradio.Checkbox(value=False, render=False, interactive=True, label="Enable thinking", info="Use thinking mode in the Jinja chat template, mostly for GRPO experiments.")
|
src/SemanticSearch.py
CHANGED
|
@@ -201,7 +201,7 @@ class SemanticSearch:
|
|
| 201 |
self.collection = client.get_collection(name=self.collection_name)
|
| 202 |
|
| 203 |
|
| 204 |
-
def extract(self, nlq: str, n_results:int=10, n_candidates:int=50,
|
| 205 |
|
| 206 |
embedding = self.get_text_embeddings_local([nlq])[0].tolist()
|
| 207 |
|
|
@@ -229,6 +229,8 @@ class SemanticSearch:
|
|
| 229 |
{"rank": result[0], "document": result[1], "uri": result[2]}
|
| 230 |
for result in zip(results["rank"][0], results["documents"][0], results["uris"][0])
|
| 231 |
], key=lambda x: x["rank"], reverse=True)
|
|
|
|
|
|
|
| 232 |
|
| 233 |
else:
|
| 234 |
|
|
|
|
| 201 |
self.collection = client.get_collection(name=self.collection_name)
|
| 202 |
|
| 203 |
|
| 204 |
+
def extract(self, nlq: str, n_results:int=10, n_candidates:int=50, rerank:bool=True, rank_cut:float=0.0):
|
| 205 |
|
| 206 |
embedding = self.get_text_embeddings_local([nlq])[0].tolist()
|
| 207 |
|
|
|
|
| 229 |
{"rank": result[0], "document": result[1], "uri": result[2]}
|
| 230 |
for result in zip(results["rank"][0], results["documents"][0], results["uris"][0])
|
| 231 |
], key=lambda x: x["rank"], reverse=True)
|
| 232 |
+
|
| 233 |
+
results = [result for result in results if result["rank"] >= rank_cut]
|
| 234 |
|
| 235 |
else:
|
| 236 |
|