Added handler

handler.py  +151 -161  CHANGED
@@ -7,133 +7,14 @@ from transformers import GenerationConfig
 import transformers
 import pandas as pd
 import time
-
-
-
-emb_gpt_genes = pd.read_pickle('./multi-modal-data/emb_gpt_genes.pickle')
-emb_hgt_genes = pd.read_pickle('./multi-modal-data/emb_hgt_genes.pickle')
-
-
-def create_prompt(prompt_config):
-
-    prompt = "[BOS]"
-
-    multi_modal_prefix = '<modality0><modality1><modality2><modality3>'*3
-
-    for k, v in prompt_config.items():
-        if k=='instruction':
-            prompt+=f"<{v}>"
-        elif k=='up':
-            prompt+=f'{multi_modal_prefix}<{k}>{v}</{k}>' if isinstance(v, str) else f'{multi_modal_prefix}<{k}>{" ".join(v)} </{k}>'
-        elif k=='down':
-            prompt+=f'{multi_modal_prefix}<{k}>{v}</{k}>' if isinstance(v, str) else f'{multi_modal_prefix}<{k}>{" ".join(v)} </{k}>'
-        else:
-            prompt+=f'<{k}>{v}</{k}>' if isinstance(v, str) else f'<{k}>{" ".join(v)} </{k}>'
-    return prompt
-
-def custom_generate(input_ids,
-                    acc_embs_up_kg_mean,
-                    acc_embs_down_kg_mean,
-                    acc_embs_up_txt_mean,
-                    acc_embs_down_txt_mean,
-                    device,
-                    max_new_tokens,
-                    num_return_sequences,
-                    temperature=0.8,
-                    top_p=0.2, top_k=3550, n_next_tokens=50,
-                    unique_compounds):
-    torch.manual_seed(137)
-
-    # Set parameters
-    # temperature - Higher value for more randomness, lower for more control
-    # top_p - Probability threshold for nucleus sampling (aka top-p sampling)
-    # top_k - Ignore logits below the top-k value to reduce randomness (if non-zero)
-    # n_next_tokens - Number of top next tokens when predicting compounds
-
-    modality0_emb = torch.unsqueeze(torch.from_numpy(acc_embs_up_kg_mean), 0).to(device) # torch.from_numpy(efo_embeddings['EFO_0002618']).type(torch.bfloat16).to(device)
-    modality1_emb = torch.unsqueeze(torch.from_numpy(acc_embs_down_kg_mean), 0).to(device)
-    modality2_emb = torch.unsqueeze(torch.from_numpy(acc_embs_up_txt_mean), 0).to(device) # torch.from_numpy(efo_embeddings['EFO_0002618']).type(torch.bfloat16).to(device)
-    modality3_emb = torch.unsqueeze(torch.from_numpy(acc_embs_down_txt_mean), 0).to(device)
-
-
-    # Generate sequences
-    outputs = []
-    next_token_compounds = []
-
-    for _ in range(num_return_sequences):
-        start_time = time.time()
-        generated_sequence = []
-        current_token = input_ids.clone()
-
-        for _ in range(max_new_tokens): # Maximum length of generated sequence
-            # Forward pass through the model
-            logits = model.forward(input_ids=current_token,
-                                   modality0_emb=modality0_emb, # torch.tensor(efo_embeddings['EFO_0002618'], dtype=torch.bfloat16).to(device),
-                                   modality0_token_id=62191,
-                                   modality1_emb=modality1_emb, # torch.tensor(efo_embeddings['EFO_0002618'], dtype=torch.bfloat16).to(device),
-                                   modality1_token_id=62192,
-                                   modality2_emb=modality2_emb, # torch.tensor(efo_embeddings['EFO_0002618'], dtype=torch.bfloat16).to(device),
-                                   modality2_token_id=62193,
-                                   modality3_emb=modality3_emb, # torch.tensor(efo_embeddings['EFO_0002618'], dtype=torch.bfloat16).to(device),
-                                   modality3_token_id=62194)[0]
-
-            # Apply temperature to logits
-            if temperature != 1.0:
-                logits = logits / temperature
-
-            # Apply top-p sampling (nucleus sampling)
-            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
-            sorted_indices_to_remove = cumulative_probs > top_p
-
-            if top_k > 0:
-                sorted_indices_to_remove[..., top_k:] = 1
-
-            # Set the logit values of the removed indices to a very small negative value
-            inf_tensor = torch.tensor(float("-inf")).type(torch.bfloat16).to(logits.device)
-
-            logits = logits.where(sorted_indices_to_remove, inf_tensor)
-
-
-            # Sample the next token
-            if current_token[0][-1] == tokenizer.encode('<drug>')[0]:
-                next_token_compounds.append(torch.topk(torch.softmax(logits, dim=-1)[0][len(current_token[0])-1, :].flatten(), 50).indices)
-
-            next_token = torch.multinomial(torch.softmax(logits, dim=-1)[0], num_samples=1)[len(current_token[0])-1, :].unsqueeze(0)
-
-
-            # Append the sampled token to the generated sequence
-            generated_sequence.append(next_token.item())
-
-            # Stop generation if an end token is generated
-            if next_token == tokenizer.eos_token_id:
-                break
-
-            # Prepare input for the next iteration
-            current_token = torch.cat((current_token, next_token), dim=-1)
-        print(time.time()-start_time)
-        outputs.append(generated_sequence)
-    return outputs, next_token_compounds
-
-
-def get_predicted_compounds(input_ids, generation_output, tokenizer, p3_compounds):
-    id_4_drug_token = list(generation_output.sequences[0][len(input_ids[0]):]).index(tokenizer.convert_tokens_to_ids(['<drug>'])[0])
-    id_4_drug_token += 1
-    print('This is token index where drug should be predicted: ', id_4_drug_token)
-
-    values, indices = torch.topk(generation_output["scores"][id_4_drug_token].view(-1), k=50)
-    indices_decoded = tokenizer.decode(indices, skip_special_tokens=True)
-
-    predicted_compound = indices_decoded.split(' ')
-    predicted_compound = [i.strip() for i in predicted_compound]
-
-    valid_compounds = sorted(set(predicted_compound) & set(p3_compounds), key = predicted_compound.index)
-    print(f"Model predicted {len(predicted_compound)} tokens. Valid compounds {len(valid_compounds)}")
-    return valid_compounds
+import numpy as np
+from precious3_gpt_multi_modal import Custom_MPTForCausalLM
 
 
 class EndpointHandler:
     def __init__(self, path=""):
+
+        self.path = path
         # load model and processor from path
         self.model = Custom_MPTForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to('cuda')
         self.tokenizer = PreTrainedTokenizerFast(tokenizer_file = os.path.join(path, "tokenizer.json"), unk_token="[UNK]",
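With create_prompt removed from handler.py, the prompt string in the request payload has to be assembled by the caller. A condensed sketch of the format the removed helper produced; the instruction tag and the 'tissue' field below are illustrative placeholders, and only the list form of the gene signatures is handled:

    # Client-side prompt assembly, condensed from the removed create_prompt helper.
    multi_modal_prefix = '<modality0><modality1><modality2><modality3>' * 3

    prompt_config = {
        "instruction": "diff2compound",   # placeholder instruction name
        "up": ["TP53", "EGFR"],           # up-regulated gene symbols
        "down": ["MYC"],                  # down-regulated gene symbols
        "tissue": "lung",                 # placeholder metadata field
    }

    prompt = "[BOS]"
    for k, v in prompt_config.items():
        if k == "instruction":
            prompt += f"<{v}>"
        elif k in ("up", "down"):
            # gene lists are preceded by the repeated modality tokens
            prompt += f'{multi_modal_prefix}<{k}>{" ".join(v)} </{k}>'
        else:
            prompt += f"<{k}>{v}</{k}>"

    print(prompt)  # "[BOS]<diff2compound>" + modality prefix + "<up>TP53 EGFR </up>" + ...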
@@ -143,10 +24,104 @@ class EndpointHandler:
         self.model.config.pad_token_id = self.tokenizer.pad_token_id
         self.model.config.bos_token_id = self.tokenizer.bos_token_id
         self.model.config.eos_token_id = self.tokenizer.eos_token_id
-        unique_entities_p3 = pd.read_csv(os.path.join(path, '
+        unique_entities_p3 = pd.read_csv(os.path.join(path, 'p3_entities_with_type.csv'))
         self.unique_compounds_p3 = [i.strip() for i in unique_entities_p3[unique_entities_p3.type=='compound'].entity.to_list()]
+
+        self.emb_gpt_genes = pd.read_pickle(os.path.join(self.path, 'multi-modal-data/emb_gpt_genes.pickle'))
+        self.emb_hgt_genes = pd.read_pickle(os.path.join(self.path, 'multi-modal-data/emb_hgt_genes.pickle'))
+
+
+
+    def custom_generate(self,
+                        input_ids,
+                        acc_embs_up_kg_mean,
+                        acc_embs_down_kg_mean,
+                        acc_embs_up_txt_mean,
+                        acc_embs_down_txt_mean,
+                        device,
+                        max_new_tokens,
+                        unique_compounds,
+                        temperature=0.8,
+                        top_p=0.2, top_k=3550,
+                        n_next_tokens=50, num_return_sequences=1):
+        torch.manual_seed(137)
+
+        # Set parameters
+        # temperature - Higher value for more randomness, lower for more control
+        # top_p - Probability threshold for nucleus sampling (aka top-p sampling)
+        # top_k - Ignore logits below the top-k value to reduce randomness (if non-zero)
+        # n_next_tokens - Number of top next tokens when predicting compounds
+
+        modality0_emb = torch.unsqueeze(torch.from_numpy(acc_embs_up_kg_mean), 0).to(device)
+        modality1_emb = torch.unsqueeze(torch.from_numpy(acc_embs_down_kg_mean), 0).to(device)
+        modality2_emb = torch.unsqueeze(torch.from_numpy(acc_embs_up_txt_mean), 0).to(device)
+        modality3_emb = torch.unsqueeze(torch.from_numpy(acc_embs_down_txt_mean), 0).to(device)
+
+
+        # Generate sequences
+        outputs = []
+        next_token_compounds = []
+
+        for _ in range(num_return_sequences):
+            start_time = time.time()
+            generated_sequence = []
+            current_token = input_ids.clone()
+
+            for _ in range(max_new_tokens): # Maximum length of generated sequence
+                # Forward pass through the model
+                logits = self.model.forward(input_ids=current_token,
+                                            modality0_emb=modality0_emb, # torch.tensor(efo_embeddings['EFO_0002618'], dtype=torch.bfloat16).to(device),
+                                            modality0_token_id=62191,
+                                            modality1_emb=modality1_emb, # torch.tensor(efo_embeddings['EFO_0002618'], dtype=torch.bfloat16).to(device),
+                                            modality1_token_id=62192,
+                                            modality2_emb=modality2_emb, # torch.tensor(efo_embeddings['EFO_0002618'], dtype=torch.bfloat16).to(device),
+                                            modality2_token_id=62193,
+                                            modality3_emb=modality3_emb, # torch.tensor(efo_embeddings['EFO_0002618'], dtype=torch.bfloat16).to(device),
+                                            modality3_token_id=62194)[0]
+
+                # Apply temperature to logits
+                if temperature != 1.0:
+                    logits = logits / temperature
+
+                # Apply top-p sampling (nucleus sampling)
+                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+                sorted_indices_to_remove = cumulative_probs > top_p
+
+                if top_k > 0:
+                    sorted_indices_to_remove[..., top_k:] = 1
+
+                # Set the logit values of the removed indices to a very small negative value
+                inf_tensor = torch.tensor(float("-inf")).type(torch.bfloat16).to(logits.device)
+
+                logits = logits.where(sorted_indices_to_remove, inf_tensor)
+
+
+                # Sample the next token
+                if current_token[0][-1] == self.tokenizer.encode('<drug>')[0]:
+                    next_token_compounds.append(torch.topk(torch.softmax(logits, dim=-1)[0][len(current_token[0])-1, :].flatten(), n_next_tokens).indices)
+
+                next_token = torch.multinomial(torch.softmax(logits, dim=-1)[0], num_samples=1)[len(current_token[0])-1, :].unsqueeze(0)
+
+
+                # Append the sampled token to the generated sequence
+                generated_sequence.append(next_token.item())
+
+                # Stop generation if an end token is generated
+                if next_token == self.tokenizer.eos_token_id:
+                    break
+
+                # Prepare input for the next iteration
+                current_token = torch.cat((current_token, next_token), dim=-1)
+            print(time.time()-start_time)
+            outputs.append(generated_sequence)
+
+        predicted_compounds_ids = [self.tokenizer.convert_ids_to_tokens(j) for j in next_token_compounds]
+        predicted_compounds = []
+        for j in predicted_compounds_ids:
+            predicted_compounds.append([i.strip() for i in j])
+        return outputs, predicted_compounds
 
-
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
         """
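custom_generate decodes token by token: it scales the logits by the temperature, masks them with a nucleus (top-p) threshold and a top-k cut, then samples with torch.multinomial. For reference, a minimal, self-contained sketch of that filtering in its standard single-vector form; it is a generic PyTorch sketch, not a line-for-line copy of the method above (which applies the mask to the full sequence logits in sorted order), and the vocabulary size below is illustrative:

    import torch

    def filter_logits(logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
        """Mask everything outside the top-k / top-p (nucleus) set of a 1-D logits vector with -inf."""
        logits = logits.clone()
        if top_k > 0:
            kth_best = torch.topk(logits, top_k).values[-1]
            logits[logits < kth_best] = float("-inf")
        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            remove = cum_probs > top_p
            remove[1:] = remove[:-1].clone()  # shift right so the first token over the threshold survives
            remove[0] = False
            logits[sorted_idx[remove]] = float("-inf")
        return logits

    # temperature first, then filtering, then sampling - the same order custom_generate uses
    logits = torch.randn(62195)  # illustrative vocabulary-sized vector of raw logits
    probs = torch.softmax(filter_logits(logits / 0.8, top_k=3550, top_p=0.2), dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)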
@@ -154,42 +129,57 @@ class EndpointHandler:
             data (:dict:):
                 The payload with the text prompt and generation parameters.
         """
-
-
+        torch.manual_seed(137)
+
+        device = "cuda"
+        prompt = data.pop("inputs", None)
         parameters = data.pop("parameters", None)
         mode = data.pop('mode', 'diff2compound')
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        inputs = self.tokenizer(prompt, return_tensors="pt")
+        input_ids = inputs["input_ids"].to(device)
+
+        max_new_tokens = 600 - len(input_ids[0])
+        try:
+            acc_embs_up1 = []
+            acc_embs_up2 = []
+            for gs in config_data['up']:
+                try:
+                    acc_embs_up1.append(self.emb_hgt_genes[self.emb_hgt_genes.gene_symbol==gs].embs.values[0])
+                    acc_embs_up2.append(self.emb_gpt_genes[self.emb_gpt_genes.gene_symbol==gs].embs.values[0])
+                except Exception as e:
+                    pass
+            acc_embs_up1_mean = np.array(acc_embs_up1).mean(0)
+            acc_embs_up2_mean = np.array(acc_embs_up2).mean(0)
+
+            acc_embs_down1 = []
+            acc_embs_down2 = []
+            for gs in config_data['down']:
+                try:
+                    acc_embs_down1.append(self.emb_hgt_genes[self.emb_hgt_genes.gene_symbol==gs].embs.values[0])
+                    acc_embs_down2.append(self.emb_gpt_genes[self.emb_gpt_genes.gene_symbol==gs].embs.values[0])
+                except Exception as e:
+                    pass
+            acc_embs_down1_mean = np.array(acc_embs_down1).mean(0)
+            acc_embs_down2_mean = np.array(acc_embs_down2).mean(0)
+
+            generated_sequence, raw_next_token_generation = self.custom_generate(input_ids = input_ids,
+                                                                                 acc_embs_up_kg_mean=acc_embs_up1_mean,
+                                                                                 acc_embs_down_kg_mean=acc_embs_down1_mean,
+                                                                                 acc_embs_up_txt_mean=acc_embs_up2_mean,
+                                                                                 acc_embs_down_txt_mean=acc_embs_down2_mean, max_new_tokens=max_new_tokens,
+                                                                                 device=device, unique_compounds=self.unique_compounds_p3, **parameters)
+            next_token_generation = [sorted(set(i) & set(self.unique_compounds_p3), key = i.index) for i in raw_next_token_generation]
+
+            if mode == "meta2diff":
+                outputs = {"up": generated_sequence, "down": generated_sequence}
+            else:
+                outputs = generated_sequence
+            out = {"output": outputs, 'compounds': next_token_generation, "raw_output": raw_next_token_generation, "mode": mode, 'message': "Done!"}
+
+        except Exception as e:
+            print(e)
+            outputs, next_token_generation = [None], [None]
+            out = {"output": outputs, "mode": mode, 'message': f"{e}"}
+
+        return out
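A sketch of how the finished handler might be exercised locally. The import path and the prompt string are assumptions; the payload keys ("inputs", "parameters", "mode") and the parameter names mirror __call__ and custom_generate above:

    # Hypothetical local smoke test; assumes handler.py sits next to tokenizer.json,
    # p3_entities_with_type.csv and multi-modal-data/, and that a CUDA device is
    # available (EndpointHandler.__init__ moves the model to 'cuda').
    from handler import EndpointHandler

    handler = EndpointHandler(path=".")

    payload = {
        "inputs": "[BOS]<diff2compound>...",   # pre-built prompt string (illustrative, see the sketch above)
        "mode": "diff2compound",               # or "meta2diff"
        "parameters": {
            "temperature": 0.8,
            "top_p": 0.2,
            "top_k": 3550,
            "n_next_tokens": 50,
            "num_return_sequences": 1,
        },
    }

    result = handler(payload)
    print(result["mode"], result["message"])   # "Done!" on success, the exception text otherwise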