daniel-dona committed on
Commit
ceccbee
·
1 Parent(s): cd181fe
Files changed (3) hide show
  1. app.py +1 -39
  2. src/Inference.py +43 -0
  3. src/SemanticSearch.py +10 -2
app.py CHANGED
@@ -1,13 +1,11 @@
1
  import os
2
  import json
3
 
4
- import spaces
5
  import gradio
6
 
7
  import numpy
8
  import pandas
9
 
10
- from transformers import AutoModelForCausalLM, AutoTokenizer
11
 
12
  import pyparseit
13
 
@@ -35,50 +33,14 @@ model_options = [
35
 
36
 
37
  from src.SemanticSearch import SemanticSearch
 
38
 
39
  extractor = SemanticSearch()
40
  extractor.load_ne_from_kg(SPARQL_ENDPOINT)
41
  extractor.build_vector_db()
42
  extractor.load_vector_db()
43
 
44
- @spaces.GPU
45
- def model_completion(messages, model_name, model_temperature, model_thinking):
46
 
47
- # load the tokenizer and the model
48
- tokenizer = AutoTokenizer.from_pretrained(model_name)
49
- model = AutoModelForCausalLM.from_pretrained(
50
- model_name,
51
- torch_dtype="auto",
52
- device_map="auto"
53
- )
54
-
55
- text = tokenizer.apply_chat_template(
56
- messages,
57
- tokenize=False,
58
- add_generation_prompt=True,
59
- enable_thinking=model_thinking
60
- )
61
-
62
- model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
63
-
64
- sample = True
65
-
66
- if model_temperature == 0:
67
- sample = False
68
-
69
-
70
- # conduct text completion
71
- generated_ids = model.generate(
72
- **model_inputs,
73
- max_new_tokens=4096,
74
- do_sample=sample,
75
- temperature=model_temperature
76
- )
77
- output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
78
-
79
- content = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
80
-
81
- return content
82
 
83
 
84
  def sparql_json_to_df(sparql_json):
 
1
  import os
2
  import json
3
 
 
4
  import gradio
5
 
6
  import numpy
7
  import pandas
8
 
 
9
 
10
  import pyparseit
11
 
 
33
 
34
 
35
  from src.SemanticSearch import SemanticSearch
36
+ from src.Inference import model_completion
37
 
38
  extractor = SemanticSearch()
39
  extractor.load_ne_from_kg(SPARQL_ENDPOINT)
40
  extractor.build_vector_db()
41
  extractor.load_vector_db()
42
 
 
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
 
46
  def sparql_json_to_df(sparql_json):
src/Inference.py CHANGED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+
5
+
6
+ @spaces.GPU
7
+ def model_completion(messages, model_name, model_temperature, model_thinking):
8
+
9
+ # load the tokenizer and the model
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+ model = AutoModelForCausalLM.from_pretrained(
12
+ model_name,
13
+ torch_dtype="auto",
14
+ device_map="auto"
15
+ )
16
+
17
+ text = tokenizer.apply_chat_template(
18
+ messages,
19
+ tokenize=False,
20
+ add_generation_prompt=True,
21
+ enable_thinking=model_thinking
22
+ )
23
+
24
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
25
+
26
+ sample = True
27
+
28
+ if model_temperature == 0:
29
+ sample = False
30
+
31
+
32
+ # conduct text completion
33
+ generated_ids = model.generate(
34
+ **model_inputs,
35
+ max_new_tokens=4096,
36
+ do_sample=sample,
37
+ temperature=model_temperature
38
+ )
39
+ output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
40
+
41
+ content = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
42
+
43
+ return content
src/SemanticSearch.py CHANGED
@@ -30,9 +30,17 @@ WHERE {
30
  #FILTER(lang(?ne_label) = "en" || lang(?ne_label) = "")
31
  #FILTER(lang(?class_label) = "en" || lang(?class_label) = "")
32
  }
33
- LIMIT 128
34
  """
35
 
 
 
 
 
 
 
 
 
 
36
  class SemanticSearch:
37
 
38
  def __init__(self, embeddings_model="BAAI/bge-base-en-v1.5", reranking_model="BAAI/bge-reranker-v2-m3"):
@@ -174,7 +182,7 @@ class SemanticSearch:
174
 
175
  print("Got ", len(documents), "sentences")
176
 
177
- for sentences_batch in tqdm.tqdm(list(itertools.batched(documents, 512)), desc="Generating embeddings"):
178
 
179
  embeddings += self.get_text_embeddings_local(sentences_batch)
180
 
 
30
  #FILTER(lang(?ne_label) = "en" || lang(?ne_label) = "")
31
  #FILTER(lang(?class_label) = "en" || lang(?class_label) = "")
32
  }
 
33
  """
34
 
35
+ # itertools.batched requires Python 3.12+, but HF seems to run 3.10, so a local backport is defined here.
36
+ def batched(iterable, n):
37
+ if n < 1:
38
+ raise ValueError('n must be at least one')
39
+ it = iter(iterable)
40
+
41
+ while batch := tuple(itertools.islice(it, n)):
42
+ yield batch
43
+
44
  class SemanticSearch:
45
 
46
  def __init__(self, embeddings_model="BAAI/bge-base-en-v1.5", reranking_model="BAAI/bge-reranker-v2-m3"):
 
182
 
183
  print("Got ", len(documents), "sentences")
184
 
185
+ for sentences_batch in tqdm.tqdm(list(batched(documents, 512)), desc="Generating embeddings"):
186
 
187
  embeddings += self.get_text_embeddings_local(sentences_batch)
188