stefan-insilico committed
Commit f07c563 · verified · 1 Parent(s): b620105

Added handler

Files changed (1)
  1. handler.py +151 -161
handler.py CHANGED
@@ -7,133 +7,14 @@ from transformers import GenerationConfig
 import transformers
 import pandas as pd
 import time
-from precious3_gpt_multi_model import Custom_MPTForCausalLM
-
-
-emb_gpt_genes = pd.read_pickle('./multi-modal-data/emb_gpt_genes.pickle')
-emb_hgt_genes = pd.read_pickle('./multi-modal-data/emb_hgt_genes.pickle')
-
-
-def create_prompt(prompt_config):
-
-    prompt = "[BOS]"
-
-    multi_modal_prefix = '<modality0><modality1><modality2><modality3>'*3
-
-    for k, v in prompt_config.items():
-        if k=='instruction':
-            prompt+=f"<{v}>"
-        elif k=='up':
-            prompt+=f'{multi_modal_prefix}<{k}>{v}</{k}>' if isinstance(v, str) else f'{multi_modal_prefix}<{k}>{" ".join(v)} </{k}>'
-        elif k=='down':
-            prompt+=f'{multi_modal_prefix}<{k}>{v}</{k}>' if isinstance(v, str) else f'{multi_modal_prefix}<{k}>{" ".join(v)} </{k}>'
-        else:
-            prompt+=f'<{k}>{v}</{k}>' if isinstance(v, str) else f'<{k}>{" ".join(v)} </{k}>'
-    return prompt
-
-def custom_generate(input_ids,
-                    acc_embs_up_kg_mean,
-                    acc_embs_down_kg_mean,
-                    acc_embs_up_txt_mean,
-                    acc_embs_down_txt_mean,
-                    device,
-                    max_new_tokens,
-                    num_return_sequences,
-                    temperature=0.8,
-                    top_p=0.2, top_k=3550, n_next_tokens=50,
-                    unique_compounds):
-    torch.manual_seed(137)
-
-    # Set parameters
-    # temperature - Higher value for more randomness, lower for more control
-    # top_p - Probability threshold for nucleus sampling (aka top-p sampling)
-    # top_k - Ignore logits below the top-k value to reduce randomness (if non-zero)
-    # n_next_tokens - Number of top next tokens when predicting compounds
-
-    modality0_emb = torch.unsqueeze(torch.from_numpy(acc_embs_up_kg_mean), 0).to(device) # torch.from_numpy(efo_embeddings['EFO_0002618']).type(torch.bfloat16).to(device)
-    modality1_emb = torch.unsqueeze(torch.from_numpy(acc_embs_down_kg_mean), 0).to(device)
-    modality2_emb = torch.unsqueeze(torch.from_numpy(acc_embs_up_txt_mean), 0).to(device) # torch.from_numpy(efo_embeddings['EFO_0002618']).type(torch.bfloat16).to(device)
-    modality3_emb = torch.unsqueeze(torch.from_numpy(acc_embs_down_txt_mean), 0).to(device)
-
-
-    # Generate sequences
-    outputs = []
-    next_token_compounds = []
-
-    for _ in range(num_return_sequences):
-        start_time = time.time()
-        generated_sequence = []
-        current_token = input_ids.clone()
-
-        for _ in range(max_new_tokens): # Maximum length of generated sequence
-            # Forward pass through the model
-            logits = model.forward(input_ids=current_token,
-                                   modality0_emb=modality0_emb, # torch.tensor(efo_embeddings['EFO_0002618'], dtype=torch.bfloat16).to(device),
-                                   modality0_token_id=62191,
-                                   modality1_emb=modality1_emb, # torch.tensor(efo_embeddings['EFO_0002618'], dtype=torch.bfloat16).to(device),
-                                   modality1_token_id=62192,
-                                   modality2_emb=modality2_emb, # torch.tensor(efo_embeddings['EFO_0002618'], dtype=torch.bfloat16).to(device),
-                                   modality2_token_id=62193,
-                                   modality3_emb=modality3_emb, # torch.tensor(efo_embeddings['EFO_0002618'], dtype=torch.bfloat16).to(device),
-                                   modality3_token_id=62194)[0]
-
-            # Apply temperature to logits
-            if temperature != 1.0:
-                logits = logits / temperature
-
-            # Apply top-p sampling (nucleus sampling)
-            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
-            sorted_indices_to_remove = cumulative_probs > top_p
-
-            if top_k > 0:
-                sorted_indices_to_remove[..., top_k:] = 1
-
-            # Set the logit values of the removed indices to a very small negative value
-            inf_tensor = torch.tensor(float("-inf")).type(torch.bfloat16).to(logits.device)
-
-            logits = logits.where(sorted_indices_to_remove, inf_tensor)
-
-
-            # Sample the next token
-            if current_token[0][-1] == tokenizer.encode('<drug>')[0]:
-                next_token_compounds.append(torch.topk(torch.softmax(logits, dim=-1)[0][len(current_token[0])-1, :].flatten(), 50).indices)
-
-            next_token = torch.multinomial(torch.softmax(logits, dim=-1)[0], num_samples=1)[len(current_token[0])-1, :].unsqueeze(0)
-
-
-            # Append the sampled token to the generated sequence
-            generated_sequence.append(next_token.item())
-
-            # Stop generation if an end token is generated
-            if next_token == tokenizer.eos_token_id:
-                break
-
-            # Prepare input for the next iteration
-            current_token = torch.cat((current_token, next_token), dim=-1)
-        print(time.time()-start_time)
-        outputs.append(generated_sequence)
-    return outputs, next_token_compounds
-
-
-def get_predicted_compounds(input_ids, generation_output, tokenizer, p3_compounds):
-    id_4_drug_token = list(generation_output.sequences[0][len(input_ids[0]):]).index(tokenizer.convert_tokens_to_ids(['<drug>'])[0])
-    id_4_drug_token += 1
-    print('This is token index where drug should be predicted: ', id_4_drug_token)
-
-    values, indices = torch.topk(generation_output["scores"][id_4_drug_token].view(-1), k=50)
-    indices_decoded = tokenizer.decode(indices, skip_special_tokens=True)
-
-    predicted_compound = indices_decoded.split(' ')
-    predicted_compound = [i.strip() for i in predicted_compound]
-
-    valid_compounds = sorted(set(predicted_compound) & set(p3_compounds), key = predicted_compound.index)
-    print(f"Model predicted {len(predicted_compound)} tokens. Valid compounds {len(valid_compounds)}")
-    return valid_compounds
+import numpy as np
+from precious3_gpt_multi_modal import Custom_MPTForCausalLM
 
 
 class EndpointHandler:
     def __init__(self, path=""):
+
+        self.path = path
         # load model and processor from path
         self.model = Custom_MPTForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to('cuda')
         self.tokenizer = PreTrainedTokenizerFast(tokenizer_file = os.path.join(path, "tokenizer.json"), unk_token="[UNK]",
@@ -143,10 +24,104 @@ class EndpointHandler:
         self.model.config.pad_token_id = self.tokenizer.pad_token_id
         self.model.config.bos_token_id = self.tokenizer.bos_token_id
         self.model.config.eos_token_id = self.tokenizer.eos_token_id
-        unique_entities_p3 = pd.read_csv(os.path.join(path, 'all_entities_with_type.csv'))
+        unique_entities_p3 = pd.read_csv(os.path.join(path, 'p3_entities_with_type.csv'))
         self.unique_compounds_p3 = [i.strip() for i in unique_entities_p3[unique_entities_p3.type=='compound'].entity.to_list()]
+
+        self.emb_gpt_genes = pd.read_pickle(os.path.join(self.path, 'multi-modal-data/emb_gpt_genes.pickle'))
+        self.emb_hgt_genes = pd.read_pickle(os.path.join(self.path, 'multi-modal-data/emb_hgt_genes.pickle'))
+
+
+
+    def custom_generate(self,
+                        input_ids,
+                        acc_embs_up_kg_mean,
+                        acc_embs_down_kg_mean,
+                        acc_embs_up_txt_mean,
+                        acc_embs_down_txt_mean,
+                        device,
+                        max_new_tokens,
+                        unique_compounds,
+                        temperature=0.8,
+                        top_p=0.2, top_k=3550,
+                        n_next_tokens=50, num_return_sequences=1):
+        torch.manual_seed(137)
+
+        # Set parameters
+        # temperature - Higher value for more randomness, lower for more control
+        # top_p - Probability threshold for nucleus sampling (aka top-p sampling)
+        # top_k - Ignore logits below the top-k value to reduce randomness (if non-zero)
+        # n_next_tokens - Number of top next tokens when predicting compounds
+
+        modality0_emb = torch.unsqueeze(torch.from_numpy(acc_embs_up_kg_mean), 0).to(device)
+        modality1_emb = torch.unsqueeze(torch.from_numpy(acc_embs_down_kg_mean), 0).to(device)
+        modality2_emb = torch.unsqueeze(torch.from_numpy(acc_embs_up_txt_mean), 0).to(device)
+        modality3_emb = torch.unsqueeze(torch.from_numpy(acc_embs_down_txt_mean), 0).to(device)
+
+
+        # Generate sequences
+        outputs = []
+        next_token_compounds = []
+
+        for _ in range(num_return_sequences):
+            start_time = time.time()
+            generated_sequence = []
+            current_token = input_ids.clone()
+
+            for _ in range(max_new_tokens): # Maximum length of generated sequence
+                # Forward pass through the model
+                logits = self.model.forward(input_ids=current_token,
+                                            modality0_emb=modality0_emb, # torch.tensor(efo_embeddings['EFO_0002618'], dtype=torch.bfloat16).to(device),
+                                            modality0_token_id=62191,
+                                            modality1_emb=modality1_emb, # torch.tensor(efo_embeddings['EFO_0002618'], dtype=torch.bfloat16).to(device),
+                                            modality1_token_id=62192,
+                                            modality2_emb=modality2_emb, # torch.tensor(efo_embeddings['EFO_0002618'], dtype=torch.bfloat16).to(device),
+                                            modality2_token_id=62193,
+                                            modality3_emb=modality3_emb, # torch.tensor(efo_embeddings['EFO_0002618'], dtype=torch.bfloat16).to(device),
+                                            modality3_token_id=62194)[0]
+
+                # Apply temperature to logits
+                if temperature != 1.0:
+                    logits = logits / temperature
+
+                # Apply top-p sampling (nucleus sampling)
+                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+                sorted_indices_to_remove = cumulative_probs > top_p
+
+                if top_k > 0:
+                    sorted_indices_to_remove[..., top_k:] = 1
+
+                # Set the logit values of the removed indices to a very small negative value
+                inf_tensor = torch.tensor(float("-inf")).type(torch.bfloat16).to(logits.device)
+
+                logits = logits.where(sorted_indices_to_remove, inf_tensor)
+
+
+                # Sample the next token
+                if current_token[0][-1] == self.tokenizer.encode('<drug>')[0]:
+                    next_token_compounds.append(torch.topk(torch.softmax(logits, dim=-1)[0][len(current_token[0])-1, :].flatten(), n_next_tokens).indices)
+
+                next_token = torch.multinomial(torch.softmax(logits, dim=-1)[0], num_samples=1)[len(current_token[0])-1, :].unsqueeze(0)
+
+
+                # Append the sampled token to the generated sequence
+                generated_sequence.append(next_token.item())
+
+                # Stop generation if an end token is generated
+                if next_token == self.tokenizer.eos_token_id:
+                    break
+
+                # Prepare input for the next iteration
+                current_token = torch.cat((current_token, next_token), dim=-1)
+            print(time.time()-start_time)
+            outputs.append(generated_sequence)
+
+        predicted_compounds_ids = [self.tokenizer.convert_ids_to_tokens(j) for j in next_token_compounds]
+        predicted_compounds = []
+        for j in predicted_compounds_ids:
+            predicted_compounds.append([i.strip() for i in j])
+        return outputs, predicted_compounds
 
-
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
         """
@@ -154,42 +129,57 @@ class EndpointHandler:
             data (:dict:):
                 The payload with the text prompt and generation parameters.
         """
-
-        inputs = data.pop("inputs", data)
+        torch.manual_seed(137)
+
+        device = "cuda"
+        prompt = data.pop("inputs", None)
         parameters = data.pop("parameters", None)
         mode = data.pop('mode', 'diff2compound')
 
-        if mode == 'diff2compound':
-            with open('./generation-configs/diff2compound.json', 'r') as f:
-                config_data = json.load(f)
-        else:
-            with open('./generation-configs/diff2compound.json', 'r') as f:
-                config_data = json.load(f)
-
-        prompt = create_prompt(config_data)
-
-        inputs = self.tokenizer(inputs, return_tensors="pt")
-        input_ids = inputs["input_ids"].to('cuda')
-
-        ### Generation config https://huggingface.co/blog/how-to-generate
-        generation_config = GenerationConfig(**parameters,
-                                             pad_token_id=self.tokenizer.pad_token_id, num_return_sequences=1)
-
-        max_new_tokens = self.model.config.max_seq_len - len(input_ids[0]) # max_new_tokens = 560 - len(input_ids[0])
-
-        torch.manual_seed(137)
-
-        with torch.no_grad():
-            generation_output = self.model.generate(
-                input_ids=input_ids,
-                generation_config=generation_config,
-                return_dict_in_generate=True,
-                output_scores=True,
-                max_new_tokens=max_new_tokens
-            )
-        if mode =='diff2compound':
-            predicted_compounds = get_predicted_compounds(input_ids=input_ids, generation_output=generation_output, tokenizer=self.tokenizer, p3_compounds=self.unique_compounds_p3)
-            output = {'output': predicted_compounds, "mode": mode, 'message': "Done!"}
-        else:
-            output = {'output': [None], "mode": mode, 'message': "Set mode"}
-        return output
+        inputs = self.tokenizer(prompt, return_tensors="pt")
+        input_ids = inputs["input_ids"].to(device)
+
+        max_new_tokens = 600 - len(input_ids[0])
+        try:
+            acc_embs_up1 = []
+            acc_embs_up2 = []
+            for gs in config_data['up']:
+                try:
+                    acc_embs_up1.append(self.emb_hgt_genes[self.emb_hgt_genes.gene_symbol==gs].embs.values[0])
+                    acc_embs_up2.append(self.emb_gpt_genes[self.emb_gpt_genes.gene_symbol==gs].embs.values[0])
+                except Exception as e:
+                    pass
+            acc_embs_up1_mean = np.array(acc_embs_up1).mean(0)
+            acc_embs_up2_mean = np.array(acc_embs_up2).mean(0)
+
+            acc_embs_down1 = []
+            acc_embs_down2 = []
+            for gs in config_data['down']:
+                try:
+                    acc_embs_down1.append(self.emb_hgt_genes[self.emb_hgt_genes.gene_symbol==gs].embs.values[0])
+                    acc_embs_down2.append(self.emb_gpt_genes[self.emb_gpt_genes.gene_symbol==gs].embs.values[0])
+                except Exception as e:
+                    pass
+            acc_embs_down1_mean = np.array(acc_embs_down1).mean(0)
+            acc_embs_down2_mean = np.array(acc_embs_down2).mean(0)
+
+            generated_sequence, raw_next_token_generation = self.custom_generate(input_ids = input_ids,
+                                                                                 acc_embs_up_kg_mean=acc_embs_up1_mean,
+                                                                                 acc_embs_down_kg_mean=acc_embs_down1_mean,
+                                                                                 acc_embs_up_txt_mean=acc_embs_up2_mean,
+                                                                                 acc_embs_down_txt_mean=acc_embs_down2_mean, max_new_tokens=max_new_tokens,
+                                                                                 device=device, unique_compounds=self.unique_compounds_p3, **parameters)
+            next_token_generation = [sorted(set(i) & set(self.unique_compounds_p3), key = i.index) for i in raw_next_token_generation]
+
+            if mode == "meta2diff":
+                outputs = {"up": generated_sequence, "down": generated_sequence}
+            else:
+                outputs = generated_sequence
+            out = {"output": outputs, 'compounds': next_token_generation, "raw_output": raw_next_token_generation, "mode": mode, 'message': "Done!"}
+
+        except Exception as e:
+            print(e)
+            outputs, next_token_generation = [None], [None]
+            out = {"output": outputs, "mode": mode, 'message': f"{e}"}
+
+        return out
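
For orientation, the removed `create_prompt` helper built the prompt string that `__call__` now receives pre-built in `inputs`. A worked example of its output format, reconstructed from the branches in the removed code above; the gene symbols are illustrative placeholders, not a validated payload:

```python
# Sketch of the prompt layout produced by the removed create_prompt helper.
# TP53/EGFR/MYC are placeholder gene symbols for illustration only.
prompt_config = {"instruction": "diff2compound", "up": ["TP53", "EGFR"], "down": ["MYC"]}

multi_modal_prefix = '<modality0><modality1><modality2><modality3>' * 3
prompt = (
    "[BOS]<diff2compound>"                          # <instruction value>
    + multi_modal_prefix + "<up>TP53 EGFR </up>"    # up-regulated genes
    + multi_modal_prefix + "<down>MYC </down>"      # down-regulated genes
)
```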
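The sampling loop in `custom_generate` combines temperature scaling with nucleus (top-p) and top-k filtering before drawing from `torch.multinomial`. A minimal self-contained sketch of that filtering step for a 1-D logits vector, with the defaults and vocabulary size assumed from the code above; this illustrates the technique, it is not the committed implementation:

```python
import torch

def filter_logits(logits: torch.Tensor, top_p: float = 0.2, top_k: int = 3550) -> torch.Tensor:
    """Mask tokens outside the nucleus (top-p) and top-k sets with -inf."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cumulative_probs > top_p
    # Shift the mask right so the first token crossing the threshold is kept,
    # guaranteeing at least one candidate survives.
    remove[1:] = remove[:-1].clone()
    remove[0] = False
    if top_k > 0:
        remove[top_k:] = True                         # drop everything beyond the top-k ranks
    filtered = logits.clone()
    filtered[sorted_indices[remove]] = float("-inf")  # map the sorted mask back to vocab order
    return filtered

logits = torch.randn(62195)                           # vocabulary size assumed for illustration
probs = torch.softmax(filter_logits(logits), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```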
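Finally, a hedged sketch of exercising the new `__call__` contract (the payload keys follow the handler above; the prompt string and sampling values are assumptions, and `path` must point at a directory containing the model weights, `tokenizer.json`, the entities CSV, and the embedding pickles):

```python
# Hypothetical invocation of the updated EndpointHandler.
handler = EndpointHandler(path=".")    # assumes all model and data files are local

payload = {
    "inputs": "[BOS]<diff2compound>",  # illustrative prompt string
    "parameters": {"temperature": 0.8, "top_p": 0.2, "top_k": 3550, "n_next_tokens": 50},
    "mode": "diff2compound",
}
result = handler(payload)
print(result["mode"], result["message"])
```

Because the whole generation path sits inside a try/except, `result["message"]` carries either "Done!" or the text of the caught exception, so it is the first field to check.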