import sys
import uuid
import tqdm
import json
import re
import requests
import itertools
import torch
import chromadb
from SPARQLWrapper import SPARQLWrapper, JSON
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer
ne_query = """
PREFIX skos: <http://www.w3.org/2004/02/skos/core#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX era: <http://data.europa.eu/949/>
SELECT ?ne_label ?ne_uri ?class_label ?class_uri
WHERE {
    VALUES ?class_uri { era:OperationalPoint era:SectionOfLine }
    ?ne_uri a ?class_uri .
    ?ne_uri skos:prefLabel|skos:altLabel|rdfs:label ?ne_label .
    ?class_uri skos:prefLabel|skos:altLabel|rdfs:label ?class_label .
    #?ne_uri era:inCountry <http://publications.europa.eu/resource/authority/country/ESP> .
    #FILTER(STRSTARTS(STR(?ne_uri), "http://data.europa.eu/949/"))
    #FILTER(lang(?ne_label) = "en" || lang(?ne_label) = "")
    #FILTER(lang(?class_label) = "en" || lang(?class_label) = "")
}
"""

# itertools.batched() only exists from Python 3.12; HF Spaces seem to run 3.10,
# so provide the equivalent recipe here.
def batched(iterable, n):
    """Yield successive n-sized tuples from iterable; the last one may be shorter."""
    if n < 1:
        raise ValueError('n must be at least one')
    it = iter(iterable)
    while batch := tuple(itertools.islice(it, n)):
        yield batch
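
# Example: list(batched("ABCDEFG", 3)) -> [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]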

class SemanticSearch:
    """Semantic search over named entities from a knowledge graph: embed labels into
    a Chroma vector store, retrieve candidates for a query, optionally rerank them."""

    def __init__(self, embeddings_model="BAAI/bge-base-en-v1.5", reranking_model="BAAI/bge-reranker-v2-m3"):
        # Keep the embeddings model name: get_text_embeddings_api() needs the
        # name as a string, not the loaded model object.
        self.embeddings_model_name = embeddings_model
        self.embeddings_tokenizer = AutoTokenizer.from_pretrained(embeddings_model)
        self.embeddings_model = AutoModel.from_pretrained(embeddings_model)
        self.reranking_tokenizer = AutoTokenizer.from_pretrained(reranking_model)
        self.reranking_model = AutoModelForSequenceClassification.from_pretrained(reranking_model)
        self.ne_values = []
        self.collection_name = "kg_ne"

    def load_ne_from_kg(self, endpoint: str):
        sparql = SPARQLWrapper(endpoint)
        sparql.setQuery(ne_query)
        sparql.setReturnFormat(JSON)
        try:
            raw_results = sparql.query().convert()
        except Exception as e:
            print(f"Error querying the endpoint: {e}", file=sys.stderr)
            return None
        processed = []
        for binding in raw_results.get("results", {}).get("bindings", []):
            # Extract values from bindings
            ne_label = binding["ne_label"]["value"]
            ne_uri = binding["ne_uri"]["value"]
            class_uri = binding["class_uri"]["value"]
            class_label = binding["class_label"]["value"]
            processed.append({
                'ne_label': ne_label,
                'ne_uri': ne_uri,
                'class_uri': class_uri,
                'class_label': class_label
            })
        self.ne_values += processed

    def load_ne_from_file(self, filename: str):
        with open(filename) as ne_file:
            self.ne_values += json.loads(ne_file.read())

    def save_ne_to_file(self, filename: str):
        with open(filename, "w+") as ne_file:
            ne_file.write(json.dumps(self.ne_values))

    def get_text_embeddings_local(self, sentences):
        encoded_input = self.embeddings_tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            model_output = self.embeddings_model(**encoded_input)
            # CLS pooling: take the hidden state of the first token of each sentence
            sentence_embeddings = model_output[0][:, 0]
        # Normalize so cosine similarity reduces to a dot product
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings

    def get_text_embeddings_api(self, endpoint, sentences):
        # Calls an OpenAI-compatible /v1/embeddings endpoint, one sentence per request.
        embeddings = []
        for sentence in sentences:
            response = requests.post(
                f"{endpoint}/v1/embeddings",
                headers={"Content-Type": "application/json"},
                data=json.dumps(
                    {
                        "model": self.embeddings_model_name,
                        "input": sentence,
                        "encoding_format": "float"
                    }
                )
            )
            embeddings.append(response.json()["data"][0]["embedding"])
        return embeddings

    def get_reranked_results(self, query, options):
        pairs = []
        for option in options:
            pairs.append([query, option])
        with torch.no_grad():
            inputs = self.reranking_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
            # The cross-encoder emits one relevance logit per (query, option) pair;
            # higher means more relevant.
            scores = self.reranking_model(**inputs, return_dict=True).logits.view(-1, ).float()
        results = []
        for score, pair in zip(scores, pairs):
            results.append([score, pair])
        return results

    def build_vector_db(self):
        client = chromadb.PersistentClient(path="./chroma_db")
        try:
            client.delete_collection(name=self.collection_name)
            print("Deleted existing collection")
        except Exception:
            print("No existing collection")
        self.collection = client.create_collection(name=self.collection_name, metadata={"hnsw:space": "cosine"})
        ids = []
        documents = []
        uris = []
        metadatas = []
        embeddings = []
        for entry in self.ne_values:
            #term = f"{entry['ne_label']} | {entry['class_label']}"
            term = f"{entry['ne_label']}"  # alternative: strip punctuation, e.g. re.sub(r'[^\w\s]', '', entry['ne_label'])
            ids.append(str(uuid.uuid4()))
            documents.append(term)
            uris.append(entry["ne_uri"])
            metadatas.append(entry)
        print("Got", len(documents), "sentences")
        for sentences_batch in tqdm.tqdm(list(batched(documents, 512)), desc="Generating embeddings"):
            embeddings += self.get_text_embeddings_local(sentences_batch)
        embeddings = [embedding.tolist() for embedding in embeddings]
        print(len(embeddings))
        self.collection.add(ids=ids, documents=documents, embeddings=embeddings, uris=uris, metadatas=metadatas)

    def load_vector_db(self):
        client = chromadb.PersistentClient(path="./chroma_db")
        self.collection = client.get_collection(name=self.collection_name)

    def extract(self, nlq: str, n_results: int = 10, n_candidates: int = 50, rerank: bool = True, rank_cut: float = 0.0):
        # Retrieve n_candidates nearest labels for the natural-language query, then
        # optionally rerank them with the cross-encoder and drop those scoring below rank_cut.
        embedding = self.get_text_embeddings_local([nlq])[0].tolist()
        results = self.collection.query(
            query_embeddings=[embedding],
            n_results=n_candidates,
            include=["documents", "distances", "metadatas", "uris"]
        )
        print(results)
        if rerank:
            documents = [f"'{item.lower()}' ({cls['class_label']})" for item, cls in zip(results["documents"][0], results["metadatas"][0])]
            rank_results = self.get_reranked_results(nlq, documents)
            print(rank_results)
            rank = [float(item[0]) for item in rank_results]
            results["rank"] = [rank]
            results = sorted([
                {"rank": result[0], "document": result[1], "uri": result[2]}
                for result in zip(results["rank"][0], results["documents"][0], results["uris"][0])
            ], key=lambda x: x["rank"], reverse=True)
            results = [result for result in results if result["rank"] >= rank_cut]
        else:
            results = [
                {"document": result[0], "uri": result[1]}
                for result in zip(results["documents"][0], results["uris"][0])
            ]
        return results[:n_results]
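
# A minimal usage sketch. The SPARQL endpoint URL below is only an illustrative
# example of an ERA knowledge-graph endpoint and may need to be adjusted;
# everything else uses the methods defined above.
if __name__ == "__main__":
    search = SemanticSearch()
    # Either pull the named entities from the knowledge graph ...
    search.load_ne_from_kg("https://data-interop.era.europa.eu/api/sparql")  # assumed endpoint URL
    search.save_ne_to_file("ne_values.json")
    # ... or reuse a previously saved dump:
    # search.load_ne_from_file("ne_values.json")
    search.build_vector_db()   # embed labels and store them in ./chroma_db
    # search.load_vector_db()  # reuse an existing ./chroma_db instead of rebuilding
    for candidate in search.extract("trains from Madrid to Barcelona", n_results=5):
        print(candidate)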