from flask import Flask, request, jsonify
from sentence_transformers import CrossEncoder
import py_vncorenlp

app = Flask(__name__)

# Load the VnCoreNLP word segmenter used by preprocess_text() below.
# Assumption: the original `rdrsegmenter` (undefined in this file) was a
# py_vncorenlp instance; the model is downloaded on first run (requires Java
# and an absolute save_dir path -- "/app/vncorenlp" here is a placeholder).
py_vncorenlp.download_model(save_dir="/app/vncorenlp")
rdrsegmenter = py_vncorenlp.VnCoreNLP(annotators=["wseg"], save_dir="/app/vncorenlp")

# Load your cross-encoder model
model_name = "truong1301/reranker_pho_BLAI"  # Replace with your actual model if different
cross_encoder = CrossEncoder(model_name, max_length=256, num_labels=1)

# Preprocess text with Vietnamese word segmentation (PhoBERT-based rerankers
# expect word-segmented input)
def preprocess_text(text):
    if not text:
        return text
    # word_segment() returns a list of word-segmented sentence strings,
    # so a single join rebuilds the full text
    segmented_sentences = rdrsegmenter.word_segment(text)
    return " ".join(segmented_sentences)

@app.route("/rerank", methods=["POST"])
def rerank():
    try:
        # Get JSON data from the request (query and list of documents);
        # silent=True avoids an unhandled exception on a malformed body
        data = request.get_json(silent=True) or {}
        query = data.get("query", "")
        documents = data.get("documents", [])

        if not query or not documents:
            return jsonify({"error": "Missing query or documents"}), 400

        # Apply Vietnamese word segmentation before scoring
        segmented_query = preprocess_text(query)
        segmented_documents = [preprocess_text(doc) for doc in documents]

        # Create (query, document) pairs for reranking
        query_doc_pairs = [(segmented_query, doc) for doc in segmented_documents]

        # Get reranking scores from the cross-encoder
        scores = cross_encoder.predict(query_doc_pairs).tolist()

        # Combine the original (unsegmented) documents with their scores and sort
        ranked_results = sorted(
            [{"document": doc, "score": score} for doc, score in zip(documents, scores)],
            key=lambda x: x["score"],
            reverse=True,
        )

        return jsonify({"results": ranked_results})
    except Exception as e:
        return jsonify({"error": str(e)}), 500

@app.route("/", methods=["GET"])
def health_check():
    return jsonify({"status": "Server is running"}), 200

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)  # Default port for Hugging Face Spaces

# Alternative version (kept for reference): uses the NlpHUST/vi-word-segmentation
# token-classification pipeline instead of VnCoreNLP for word segmentation.

# from flask import Flask, request, jsonify
# from transformers import pipeline
# from sentence_transformers import CrossEncoder

# app = Flask(__name__)

# # Load Vietnamese word segmentation pipeline
# segmenter = pipeline("token-classification", model="NlpHUST/vi-word-segmentation")

# # Load your cross-encoder model
# model_name = "truong1301/reranker_pho_BLAI"  # Replace with your actual model if different
# cross_encoder = CrossEncoder(model_name, max_length=256, num_labels=1)

# # Function to preprocess text using Vietnamese word segmentation
# def preprocess_text(text):
#     if not text:
#         return text
#     ner_results = segmenter(text)
#     segmented_text = ""
#     for e in ner_results:
#         if "##" in e["word"]:
#             segmented_text += e["word"].replace("##", "")
#         elif e["entity"] == "I":
#             segmented_text += "_" + e["word"]
#         else:
#             segmented_text += " " + e["word"]
#     return segmented_text.strip()

# @app.route("/rerank", methods=["POST"])
# def rerank():
#     try:
#         # Get JSON data from the request (query and list of documents)
#         data = request.get_json()
#         query = data.get("query", "")
#         documents = data.get("documents", [])

#         if not query or not documents:
#             return jsonify({"error": "Missing query or documents"}), 400

#         # Apply Vietnamese word segmentation preprocessing
#         segmented_query = preprocess_text(query)
#         segmented_documents = [preprocess_text(doc) for doc in documents]

#         # Create pairs of query and documents for reranking
#         query_doc_pairs = [(segmented_query, doc) for doc in segmented_documents]

#         # Get reranking scores from the cross-encoder
#         scores = cross_encoder.predict(query_doc_pairs).tolist()

#         # Combine documents with their scores and sort
#         ranked_results = sorted(
#             [{"document": doc, "score": score} for doc, score in zip(documents, scores)],
#             key=lambda x: x["score"],
#             reverse=True
#         )

#         return jsonify({"results": ranked_results})
#     except Exception as e:
#         return jsonify({"error": str(e)}), 500
@app.route("/", methods=["GET"]) # def health_check(): # return jsonify({"status": "Server is running"}), 200 # if __name__ == "__main__": # app.run(host="0.0.0.0", port=7860) # Default port for Hugging Face Spaces