AryaWu committed
Commit 67797ad
1 Parent(s): e12a8f9

Update app.py

Files changed (1):
  app.py +64 -99
app.py CHANGED
@@ -17,8 +17,13 @@ CONFIG.set_default_api_key(api_key)
 
 access_token = os.environ['HUGGING_FACE_HUB_TOKEN']
 
-# Load the Language Model
-llama = LanguageModel("meta-llama/Meta-Llama-3.1-8B", token=access_token)
+
+# Model options
+MODEL_OPTIONS = {
+    "Llama3.1-8B": "meta-llama/Meta-Llama-3.1-8B",
+    "Llama3.1-70B": "meta-llama/Meta-Llama-3.1-70B",
+}
+
 
 #placeholder for reset
 prompts_with_probs = pd.DataFrame(
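
With this hunk the app stops loading a single global model at import time: the hard-coded `llama` is replaced by a label-to-repo-id map, and the model is constructed per request in `submit_prompts` below. A minimal sketch of the new pattern, assuming nnsight's `LanguageModel` accepts a Hugging Face repo id as in the removed call (the `load_model` helper is illustrative, not part of the commit):

    from nnsight import LanguageModel

    MODEL_OPTIONS = {
        "Llama3.1-8B": "meta-llama/Meta-Llama-3.1-8B",
        "Llama3.1-70B": "meta-llama/Meta-Llama-3.1-70B",
    }

    def load_model(model_name: str) -> LanguageModel:
        # Resolve the dropdown label to a repo id and build the model on
        # demand, instead of holding one global 8B model for the app's lifetime.
        return LanguageModel(MODEL_OPTIONS[model_name])

Note that the new call site drops the explicit token= argument, so access to gated checkpoints presumably relies on the ambient HUGGING_FACE_HUB_TOKEN environment variable read above.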
@@ -53,7 +58,7 @@ def run_lens(model,PROMPT):
             logits_lens_probs_by_layer.append(logits_lens_probs)
             logits_lens_next_token = torch.argmax(logits_lens_probs, dim=1).save()
             logits_lens_token_result_by_layer.append(logits_lens_next_token)
-        tokens_out = llama.lm_head.output.argmax(dim=-1).save()
+        tokens_out = model.lm_head.output.argmax(dim=-1).save()
         expected_token = tokens_out[0][-1].save()
     # logits_lens_all_probs = np.concatenate([probs[:, expected_token].cpu().detach().numpy() for probs in logits_lens_probs_by_layer])
     logits_lens_all_probs = np.concatenate([probs[:, expected_token].cpu().detach().to(torch.float32).numpy() for probs in logits_lens_probs_by_layer])
@@ -65,7 +70,7 @@ def run_lens(model,PROMPT):
         # Find the rank of the expected token (1-based rank)
         expected_token_rank = (sorted_indices == expected_token).nonzero(as_tuple=True)[1].item() + 1
         logits_lens_ranks_by_layer.append(expected_token_rank)
-    actual_output = llama.tokenizer.decode(expected_token.item())
+    actual_output = model.tokenizer.decode(expected_token.item())
     logits_lens_results = [model.tokenizer.decode(next_token.item()) for next_token in logits_lens_token_result_by_layer]
     return logits_lens_results, logits_lens_all_probs, actual_output, logits_lens_ranks_by_layer
 
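
Both `run_lens` hunks swap the old module-level `llama` for the `model` parameter; with the global gone, the old names would raise NameError at trace time. The rank bookkeeping kept as context above reduces to this self-contained helper (illustrative, not code from the commit):

    import torch

    def token_rank(probs: torch.Tensor, token_id: int) -> int:
        # probs: a (1, vocab_size) distribution from one layer's logit-lens readout.
        sorted_indices = torch.argsort(probs, descending=True)
        # 1-based rank: where token_id lands when tokens are sorted by probability.
        return (sorted_indices == token_id).nonzero(as_tuple=True)[1].item() + 1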
@@ -98,28 +103,35 @@ def process_file(prompts_data,file_path):
 
 def plot_prob(prompts_with_probs):
     plt.figure(figsize=(10, 6))
-
+    texts = []  # List to hold text annotations for adjustment
+
     # Iterate over each prompt and plot its probabilities
     for prompt in prompts_with_probs['prompt'].unique():
         # Filter the DataFrame for the current prompt
         prompt_data = prompts_with_probs[prompts_with_probs['prompt'] == prompt]
+        label = f"{prompt}({prompt_data['expected'].iloc[0]})"
 
         # Plot probabilities for this prompt
-        plt.plot(prompt_data['layer'], prompt_data['probs'], marker='x', label=prompt)
-
+        plt.plot(prompt_data['layer'], prompt_data['probs'], marker='x', label=label)
+
         # Annotate each point with the corresponding result
         for layer, prob, result in zip(prompt_data['layer'], prompt_data['probs'], prompt_data['results']):
-            plt.text(layer, prob, result, fontsize=8)
-
+            text = plt.text(layer, prob, result, fontsize=8)
+            texts.append(text)  # Add text to the list
 
     # Add labels and title
     plt.xlabel('Layer Number')
-    plt.ylabel('Probability of Expected Token')
-    plt.title('Prob of expected token across layers\n(annotated with actual decoded output at each layer)')
+    plt.ylabel('Probability')
+    plt.title('Probability of most-likely output token')
     plt.grid(True)
+    plt.xlim(0, max(prompts_with_probs['layer']))
     plt.ylim(0.0, 1.0)
     plt.legend(title='Prompts', bbox_to_anchor=(0.5, -0.15), loc='upper center', ncol=1)
 
+    # Adjust text to prevent overlap
+    adjust_text(texts, only_move={'points': 'xy', 'texts': 'xy'},
+                arrowprops=dict(arrowstyle="->", color='r', lw=0.5))
+
     # Save the plot to a buffer
     buf = io.BytesIO()
     plt.savefig(buf, format='png', bbox_inches='tight')  # Use bbox_inches to avoid cutting off labels
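
Both plotting functions end with the same render-to-PIL tail that this hunk leaves untouched; isolated here for reference (a sketch of the existing pattern, not new code in the commit):

    import io
    import matplotlib.pyplot as plt
    from PIL import Image

    def figure_to_pil() -> Image.Image:
        # Serialize the current matplotlib figure to an in-memory PNG and hand
        # it to Gradio as a PIL image; closing the figure frees its memory.
        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        img = Image.open(buf)
        plt.close()
        return img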
@@ -130,27 +142,34 @@ def plot_prob(prompts_with_probs):
 
 def plot_rank(prompts_with_ranks):
     plt.figure(figsize=(10, 6))
-
+    texts = []  # List to hold text annotations for adjustment
+
     # Iterate over each prompt and plot its ranks
     for prompt in prompts_with_ranks['prompt'].unique():
         # Filter the DataFrame for the current prompt
         prompt_data = prompts_with_ranks[prompts_with_ranks['prompt'] == prompt]
+        label = f"{prompt}({prompt_data['expected'].iloc[0]})"
 
         # Plot ranks for this prompt
-        plt.plot(prompt_data['layer'], prompt_data['ranks'], marker='x', label=prompt)
-
+        plt.plot(prompt_data['layer'], prompt_data['ranks'], marker='x', label=label)
+
         # Annotate each point with the corresponding result
        for layer, rank, result in zip(prompt_data['layer'], prompt_data['ranks'], prompt_data['results']):
-            plt.text(layer, rank, result, ha='right', va='bottom', fontsize=8)
+            text = plt.text(layer, rank, result, ha='right', va='bottom', fontsize=8)
+            texts.append(text)  # Add text to the list
 
     # Add labels and title
     plt.xlabel('Layer Number')
-    plt.ylabel('Rank of Expected Token')
-    plt.title('Rank of expected token across layers\n(annotated with decoded output at each layer)')
+    plt.ylabel('Rank')
+    plt.title('Rank of most-likely output token')
     plt.grid(True)
+    plt.xlim(0, max(prompts_with_ranks['layer']))
     plt.ylim(bottom=0)  # Adjust if needed, depending on your rank values
     plt.legend(title='Prompts', bbox_to_anchor=(0.5, -0.15), loc='upper center', ncol=1)
 
+    # Adjust text to prevent overlap
+    adjust_text(texts, only_move={'points': 'xy', 'texts': 'xy'},
+                arrowprops=dict(arrowstyle="->", color='r', lw=0.5))
 
     # Save the plot to a buffer
     buf = io.BytesIO()
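
The new `adjust_text(...)` calls in both plot functions come from the adjustText package. Assuming the import is not already present earlier in app.py (this diff does not add it), the file and the Space's requirements would also need:

    # Assumed prerequisite for the adjust_text(...) calls above
    # (pip install adjustText).
    from adjustText import adjust_text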
@@ -160,77 +179,8 @@ def plot_rank(prompts_with_ranks):
     plt.close()  # Close the figure to free memory
     return img
 
-def plot_prob_mean(prompts_with_probs):
-    # Calculate mean probabilities and variance
-    summary_stats = prompts_with_probs.groupby("prompt")["probs"].agg(
-        mean_prob="mean",
-        variance="var"
-    ).reset_index()
-
-    # Set up the bar plot
-    plt.figure(figsize=(10, 6))
-    bars = plt.bar(summary_stats['prompt'], summary_stats['mean_prob'],
-                   yerr=summary_stats['variance']**0.5,  # Error bars are the standard deviation
-                   capsize=5, color='skyblue')
-
-    # Add labels and title
-    plt.xlabel('Prompt')
-    plt.ylabel('Mean Probability')
-    plt.title('Mean Probability of Expected Token')
-    plt.xticks(rotation=45, ha='right')
-    plt.grid(axis='y')
-    plt.ylim(0, 1)
-
-
-    # Annotate the mean and variance on the bars
-    for bar, mean, var in zip(bars, summary_stats['mean_prob'], summary_stats['variance']):
-        yval = bar.get_height()
-        plt.text(bar.get_x() + bar.get_width() / 2, yval, f'Mean: {mean:.2f}\nVar: {var:.2f}',
-                 ha='center', va='bottom', fontsize=8, color='black')
-
-    # Save the plot to a buffer
-    buf = io.BytesIO()
-    plt.savefig(buf, format='png', bbox_inches='tight')  # Use bbox_inches to avoid cutting off labels
-    buf.seek(0)
-    img = Image.open(buf)
-    plt.close()  # Close the figure to free memory
-    return img
-
-def plot_rank_mean(prompts_with_ranks):
-    # Calculate mean ranks and variance
-    summary_stats = prompts_with_ranks.groupby("prompt")["ranks"].agg(
-        mean_rank="mean",
-        variance="var"
-    ).reset_index()
-
-    # Set up the bar plot
-    plt.figure(figsize=(10, 6))
-    bars = plt.bar(summary_stats['prompt'], summary_stats['mean_rank'],
-                   yerr=summary_stats['variance']**0.5,  # Error bars are the standard deviation
-                   capsize=5, color='salmon')
-
-    # Add labels and title
-    plt.xlabel('Prompt')
-    plt.ylabel('Mean Rank')
-    plt.title('Mean Rank of Expected Token')
-    plt.xticks(rotation=45, ha='right')
-    plt.grid(axis='y')
-
-    # Annotate the mean and variance on the bars
-    for bar, mean, var in zip(bars, summary_stats['mean_rank'], summary_stats['variance']):
-        yval = bar.get_height()
-        plt.text(bar.get_x() + bar.get_width() / 2, yval, f'Mean: {mean:.2f}\nVar: {var:.2f}',
-                 ha='center', va='bottom', fontsize=8, color='black')
-
-    # Save the plot to a buffer
-    buf = io.BytesIO()
-    plt.savefig(buf, format='png', bbox_inches='tight')  # Use bbox_inches to avoid cutting off labels
-    buf.seek(0)
-    img = Image.open(buf)
-    plt.close()  # Close the figure to free memory
-    return img
-
-def submit_prompts(prompts_data):
+def submit_prompts(model_name, prompts_data):
+    llama = LanguageModel(MODEL_OPTIONS[model_name])
     # Initialize lists to accumulate results
     all_prompts = []
     all_results = []
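
For reference, the new signature can be exercised outside Gradio; a hypothetical direct call (the prompt list shape matches the gr.Dataframe "array" type used below, and both return values are PIL images):

    prob_img, rank_img = submit_prompts(
        "Llama3.1-8B",
        [["The Eiffel Tower is located in the city of"]],
    )
    prob_img.save("prob.png")
    rank_img.save("rank.png")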
@@ -276,20 +226,23 @@ def submit_prompts(prompts_data):
         "ranks": all_ranks,
         "expected": all_expected,
     })
-    return plot_prob(prompts_with_probs), plot_rank(prompts_with_ranks), plot_prob_mean(prompts_with_probs), plot_rank_mean(prompts_with_ranks)
+    return plot_prob(prompts_with_probs), plot_rank(prompts_with_ranks)
 
 def clear_all(prompts):
     prompts = [['']]
     # prompt_file=gr.File(type="filepath", label="Upload a File with Prompts")
     prompt_file = None
     prompts_data = gr.Dataframe(headers=["Prompt"], row_count=5, col_count=1, value=prompts, type="array", interactive=True)
-    return prompts_data, prompt_file, plot_prob(prompts_with_probs), plot_rank(prompts_with_ranks), plot_prob_mean(prompts_with_probs), plot_rank_mean(prompts_with_ranks)
+    return prompts_data, prompt_file, plot_prob(prompts_with_probs), plot_rank(prompts_with_ranks)
 
 def gradio_interface():
     with gr.Blocks(theme="gradio/monochrome") as demo:
-        prompts = [['']]
+        prompts = [['The Eiffel Tower is located in the city of'], ['Vatican is located in the city of']]
+
+        # prompts=[['']]
         with gr.Row():
             with gr.Column(scale=3):
+                model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Select Model", value="Llama3.1-8B")
                 prompts_data = gr.Dataframe(headers=["Prompt"], row_count=5, col_count=1, value=prompts, type="array", interactive=True)
             with gr.Column(scale=1):
                 prompt_file = gr.File(type="filepath", label="Upload a File with Prompts")
@@ -298,18 +251,30 @@ def gradio_interface():
         with gr.Row():
             clear_btn = gr.Button("Clear")
             submit_btn = gr.Button("Submit")
+
+        prompt_file.upload(process_file, inputs=[prompts_data, prompt_file], outputs=[prompts_data])
+
+
+        gr.Markdown("The most likely output token is the model's prediction at the final layer, shown in brackets in the plot legend.")
+        # Create a Markdown component for the description
         with gr.Row():
-            prob_visualization = gr.Image(value=plot_prob(prompts_with_probs), type="pil", label=" ")
-            rank_visualization = gr.Image(value=plot_rank(prompts_with_ranks), type="pil", label=" ")
+            gr.Markdown("The graph below illustrates the probability of this most likely output token as it is decoded at each layer of the model. Each point on the graph is annotated with the decoded output corresponding to the token that has the highest probability at that particular layer.")
+            gr.Markdown("The graph below illustrates the rank of this most likely output token as it is decoded at each layer of the model. Each point on the graph is annotated with the decoded output corresponding to the token that has the lowest rank at that particular layer.")
+
+        prob_img, rank_img = submit_prompts(model_dropdown.value, prompts)
+        # prob_visualization.value = prob_img  # Direct assignment to value
+        # rank_visualization.value = rank_img  # Direct assignment to value
+
         with gr.Row():
-            prob_mean_visualization = gr.Image(value=plot_prob_mean(prompts_with_probs), type="pil", label=" ")
-            rank_mean_visualization = gr.Image(value=plot_rank_mean(prompts_with_ranks), type="pil", label=" ")
+            prob_visualization = gr.Image(value=prob_img, type="pil", label=" ")
+            rank_visualization = gr.Image(value=rank_img, type="pil", label=" ")
 
-        clear_btn.click(clear_all, inputs=[prompts_data], outputs=[prompts_data, prompt_file, prob_visualization, rank_visualization, prob_mean_visualization, rank_mean_visualization])
-        submit_btn.click(submit_prompts, inputs=[prompts_data], outputs=[prob_visualization, rank_visualization, prob_mean_visualization, rank_mean_visualization])
-        prompt_file.clear(clear_all, inputs=[prompts_data], outputs=[prompts_data, prompt_file, prob_visualization, rank_visualization, prob_mean_visualization, rank_mean_visualization])
+        clear_btn.click(clear_all, inputs=[prompts_data], outputs=[prompts_data, prompt_file, prob_visualization, rank_visualization])
+        submit_btn.click(submit_prompts, inputs=[model_dropdown, prompts_data], outputs=[prob_visualization, rank_visualization])
+        prompt_file.clear(clear_all, inputs=[prompts_data], outputs=[prompts_data, prompt_file, prob_visualization, rank_visualization])
+
+        # Generate plots with sample prompts on load
 
-
     demo.launch()
 
 gradio_interface()
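
Taken together, the interface changes reduce to the standalone sketch below: the dropdown is threaded through as the first input of submit_prompts, the two mean plots are gone, and the remaining two images are pre-populated at build time from the sample prompts. Component names follow the diff; the handler is stubbed here, so this is a wiring sketch rather than a drop-in app.py:

    import gradio as gr

    MODEL_OPTIONS = {
        "Llama3.1-8B": "meta-llama/Meta-Llama-3.1-8B",
        "Llama3.1-70B": "meta-llama/Meta-Llama-3.1-70B",
    }

    def submit_prompts(model_name, prompts_data):
        # Stub standing in for the real handler, which runs the logit lens
        # and returns (probability_plot, rank_plot) as PIL images.
        return None, None

    with gr.Blocks(theme="gradio/monochrome") as demo:
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()),
                                     value="Llama3.1-8B", label="Select Model")
        prompts_data = gr.Dataframe(headers=["Prompt"], type="array", interactive=True)
        submit_btn = gr.Button("Submit")
        prob_visualization = gr.Image(type="pil")
        rank_visualization = gr.Image(type="pil")

        # The dropdown rides along as the first input, matching the new
        # submit_prompts(model_name, prompts_data) signature.
        submit_btn.click(submit_prompts,
                         inputs=[model_dropdown, prompts_data],
                         outputs=[prob_visualization, rank_visualization])

    demo.launch()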