ICTuniverse committed on
Commit
6378099
·
verified ·
1 Parent(s): 32563c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -8
app.py CHANGED
@@ -1,22 +1,90 @@
1
- from flask import Flask, request, jsonify
2
- from sentence_transformers import CrossEncoder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
 
 
 
 
 
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  app = Flask(__name__)
7
 
 
 
8
 
9
  # Load your cross-encoder model
10
  model_name = "truong1301/reranker_pho_BLAI" # Replace with your actual model if different
11
  cross_encoder = CrossEncoder(model_name, max_length=256, num_labels=1)
12
 
13
- # Function to preprocess text with Vietnamese word segmentation
14
  def preprocess_text(text):
15
  if not text:
16
  return text
17
- segmented_text = rdrsegmenter.word_segment(text)
18
- # Join tokenized sentences into a single string
19
- return " ".join([" ".join(sentence) for sentence in segmented_text])
 
 
 
 
 
 
 
 
 
 
20
 
21
  @app.route("/rerank", methods=["POST"])
22
  def rerank():
@@ -29,8 +97,12 @@ def rerank():
29
  if not query or not documents:
30
  return jsonify({"error": "Missing query or documents"}), 400
31
 
 
 
 
 
32
  # Create pairs of query and documents for reranking
33
- query_doc_pairs = [(query, doc) for doc in documents]
34
 
35
  # Get reranking scores from the cross-encoder
36
  scores = cross_encoder.predict(query_doc_pairs).tolist()
@@ -55,4 +127,3 @@ if __name__ == "__main__":
55
  app.run(host="0.0.0.0", port=7860) # Default port for Hugging Face Spaces
56
 
57
 
58
-
 
1
+ # from flask import Flask, request, jsonify
2
+ # from sentence_transformers import CrossEncoder
3
+
4
+
5
+
6
+ # app = Flask(__name__)
7
+
8
+
9
+ # # Load your cross-encoder model
10
+ # model_name = "truong1301/reranker_pho_BLAI" # Replace with your actual model if different
11
+ # cross_encoder = CrossEncoder(model_name, max_length=256, num_labels=1)
12
+
13
+ # # Function to preprocess text with Vietnamese word segmentation
14
+ # def preprocess_text(text):
15
+ # if not text:
16
+ # return text
17
+ # segmented_text = rdrsegmenter.word_segment(text)
18
+ # # Join tokenized sentences into a single string
19
+ # return " ".join([" ".join(sentence) for sentence in segmented_text])
20
+
21
+ # @app.route("/rerank", methods=["POST"])
22
+ # def rerank():
23
+ # try:
24
+ # # Get JSON data from the request (query and list of documents)
25
+ # data = request.get_json()
26
+ # query = data.get("query", "")
27
+ # documents = data.get("documents", [])
28
+
29
+ # if not query or not documents:
30
+ # return jsonify({"error": "Missing query or documents"}), 400
31
+
32
+ # # Create pairs of query and documents for reranking
33
+ # query_doc_pairs = [(query, doc) for doc in documents]
34
+
35
+ # # Get reranking scores from the cross-encoder
36
+ # scores = cross_encoder.predict(query_doc_pairs).tolist()
37
 
38
+ # # Combine documents with their scores and sort
39
+ # ranked_results = sorted(
40
+ # [{"document": doc, "score": score} for doc, score in zip(documents, scores)],
41
+ # key=lambda x: x["score"],
42
+ # reverse=True
43
+ # )
44
 
45
+ # return jsonify({"results": ranked_results})
46
+
47
+ # except Exception as e:
48
+ # return jsonify({"error": str(e)}), 500
49
+
50
+ # @app.route("/", methods=["GET"])
51
+ # def health_check():
52
+ # return jsonify({"status": "Server is running"}), 200
53
+
54
+ # if __name__ == "__main__":
55
+ # app.run(host="0.0.0.0", port=7860) # Default port for Hugging Face Spaces
56
+
57
+
58
+ from flask import Flask, request, jsonify
59
+ from transformers import pipeline
60
+ from sentence_transformers import CrossEncoder
61
 
62
  app = Flask(__name__)
63
 
64
+ # Load Vietnamese word segmentation pipeline
65
+ segmenter = pipeline("token-classification", model="NlpHUST/vi-word-segmentation")
66
 
67
  # Load your cross-encoder model
68
  model_name = "truong1301/reranker_pho_BLAI" # Replace with your actual model if different
69
  cross_encoder = CrossEncoder(model_name, max_length=256, num_labels=1)
70
 
71
+ # Function to preprocess text using Vietnamese word segmentation
72
  def preprocess_text(text):
73
  if not text:
74
  return text
75
+
76
+ ner_results = segmenter(text)
77
+ segmented_text = ""
78
+
79
+ for e in ner_results:
80
+ if "##" in e["word"]:
81
+ segmented_text += e["word"].replace("##", "")
82
+ elif e["entity"] == "I":
83
+ segmented_text += "_" + e["word"]
84
+ else:
85
+ segmented_text += " " + e["word"]
86
+
87
+ return segmented_text.strip()
88
 
89
  @app.route("/rerank", methods=["POST"])
90
  def rerank():
 
97
  if not query or not documents:
98
  return jsonify({"error": "Missing query or documents"}), 400
99
 
100
+ # Apply Vietnamese word segmentation preprocessing
101
+ segmented_query = preprocess_text(query)
102
+ segmented_documents = [preprocess_text(doc) for doc in documents]
103
+
104
  # Create pairs of query and documents for reranking
105
+ query_doc_pairs = [(segmented_query, doc) for doc in segmented_documents]
106
 
107
  # Get reranking scores from the cross-encoder
108
  scores = cross_encoder.predict(query_doc_pairs).tolist()
 
127
  app.run(host="0.0.0.0", port=7860) # Default port for Hugging Face Spaces
128
 
129