import argparse
import json
import os
import subprocess
import time
from io import BytesIO
from threading import Thread
from typing import List, Optional, Tuple, Union

import datasets
import matplotlib.pyplot as plt
import numpy as np
import requests
import seaborn as sns
import torch
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont
from scipy.ndimage import gaussian_filter
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    CLIPImageProcessor,
    CLIPVisionModel,
    LlamaConfig,
    LlamaForCausalLM,
    LlamaModel,
    StoppingCriteria,
)
from transformers.generation.streamers import TextIteratorStreamer
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria, process_images
from llava.model import *
from llava.model import LlavaLlamaForCausalLM
from llava.model.builder import load_pretrained_model
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
from llava.utils import disable_torch_init

parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, default="/jizhicfs/bojoli/mmpe/mmpe-main/checkpoints/mmpe_finetune_vicuna-7b-1.5_clip-vit-large-patch14-336/")
parser.add_argument('--image', type=str)
parser.add_argument('--prompt', type=str, default="Describe this image.")
parser.add_argument('--output', type=str, default='attention_map/mmpe')
parser.add_argument('--layer', type=int, default=32)
parser.add_argument('--w', action='store_true', help="Use the MMPE checkpoint and output directory (otherwise the without-MMPE ones)")
parser.add_argument('--position', type=int, default=0)
parser.add_argument('--target_text', type=str, default=None)
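# Example invocations (the script filename below is only a placeholder for this file):
#   python visualize_position_attention.py --w --target_text "magnet"
#   python visualize_position_attention.py --position 950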

DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


def new_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    images: Optional[torch.FloatTensor] = None,
    image_sizes: Optional[List[List[int]]] = None,
    return_dict: Optional[bool] = None,
    modalities: Optional[List[str]] = ["image"],
    dpo_forward: Optional[bool] = None,
    cache_position=None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Replacement forward that also returns the multimodal position_ids produced by
    prepare_inputs_labels_for_multimodal, so the caller can inspect them."""
    if inputs_embeds is None:
        (
            input_ids,
            position_ids,
            attention_mask,
            past_key_values,
            inputs_embeds,
            labels,
        ) = self.prepare_inputs_labels_for_multimodal(
            input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes
        )

    if dpo_forward:
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        return logits, labels
    else:
        return position_ids, LlavaLlamaForCausalLM.forward(
            self,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


if __name__ == "__main__":
    dataset_path = "/jizhicfs/bojoli/dataset/mmbench/en"
    args = parser.parse_args()
    if args.w:
        args.model_path = "/jizhicfs/bojoli/mmpe/mmpe-main/checkpoints/final_mmpe_finetune_vicuna-7b-1.5_clip-vit-large-patch14-336/"
        args.output = "attention_map/mmpe"
    else:
        args.model_path = "/jizhicfs/bojoli/mmpe/mmpe-main/checkpoints/final_without_mmpe_finetune_vicuna-7b-1.5_clip-vit-large-patch14-336/"
        args.output = "attention_map/without_mmpe"
    mmbench_data = datasets.load_dataset(dataset_path, split='validation')

    # Pick the MMBench example whose question matches the one below.
    for i in range(len(mmbench_data)):
        if mmbench_data[i]['question'] == 'Think about the magnetic force between the magnets in each pair. Which of the following statements is true?':
            print(mmbench_data[i])
            break

    width, height = mmbench_data[i]['image'].size
    mmbench_data[i]['image'].save('example.jpg')
    disable_torch_init()

    model = LlavaLlamaForCausalLM.from_pretrained(args.model_path).cuda()
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower)

    conv_mode = "llava_v1"
    conv = conv_templates[conv_mode].copy()

    image_data = [mmbench_data[i]['image']]
    image_tensor = process_images(image_data, image_processor, model.config).cuda()

    inp = DEFAULT_IMAGE_TOKEN + "\n" + mmbench_data[i]['question']
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    # Bind new_forward as the model's forward so the call below also returns position_ids.
    model.forward = new_forward.__get__(model, LlavaLlamaForCausalLM)
    image_size_list = torch.Tensor([[width, height]]).cuda()

    with torch.no_grad():
        position_ids, output = model(
            input_ids,
            images=image_tensor,
            image_sizes=image_size_list,
            output_attentions=True,
            return_dict=True
        )

    print(position_ids)

    position_id_list = []
    for j in position_ids[0]:
        position_id_list.append(j.item())
    print(position_id_list)

    target_index = 36 + 574
    flag = False
    target_position = args.position
    if target_position != 0:
        # Map the requested position id to a sequence index by taking its second
        # occurrence in the position_ids sequence.
        for j in range(0, position_ids.shape[1]):
            if position_id_list[j] == target_position:
                if not flag:
                    flag = True
                else:
                    target_position = j
                    break
    num_layer = len(output.attentions)
    input_ids_list = input_ids[0].cpu().tolist()

    instruction_begin_index = input_ids_list.index(IMAGE_TOKEN_INDEX)
    len_instruc = len(input_ids[0]) - instruction_begin_index - 1
    begin_2 = 612
    end_2 = position_ids.shape[1] - len_instruc
    target_position = [target_position]

    img = mmbench_data[i]['image']
    orig_width, orig_height = img.size
    img = img.resize((336, 336), Image.BILINEAR)

    aspect_ratio = orig_width / orig_height

    # Number of position ids in the image region of the sequence.
    position_region_length = end_2 - begin_2

    # Choose a grid whose aspect ratio roughly matches the original image and
    # whose cell count covers the whole position region.
    if aspect_ratio >= 1.0:
        grid_width = int(np.sqrt(position_region_length * aspect_ratio))
        grid_height = int(position_region_length / grid_width)
        while grid_width * grid_height < position_region_length:
            grid_height += 1
    else:
        grid_height = int(np.sqrt(position_region_length / aspect_ratio))
        grid_width = int(position_region_length / grid_height)
        while grid_width * grid_height < position_region_length:
            grid_width += 1

    # Fill the grid with the position ids of the image region (row-major).
    position_grid = np.zeros((grid_height, grid_width), dtype=int)
    for idx, pos_id in enumerate(range(begin_2, end_2)):
        if idx < grid_width * grid_height:
            row = idx // grid_width
            col = idx % grid_width
            position_grid[row, col] = pos_id

    position_img = Image.fromarray(position_grid.astype(np.uint8) * 255 // (end_2 - begin_2))  # not used below
    position_img = position_img.resize((336, 336), Image.BILINEAR)

    # Figure 1: highlight a single position id on the resized image.
    plt.figure(figsize=(10, 10))
    plt.imshow(img)

    highlighted_position = 950
    highlighted_x = None
    highlighted_y = None

    for y in range(grid_height):
        for x in range(grid_width):
            if y * grid_width + x < position_region_length:
                pos_id = position_grid[y, x]

                # Centre of this grid cell in 336x336 image coordinates.
                img_x = int(x * 336 / grid_width + 336 / (2 * grid_width))
                img_y = int(y * 336 / grid_height + 336 / (2 * grid_height))

                if pos_id == highlighted_position:
                    highlighted_x = img_x
                    highlighted_y = img_y

    cell_width = 336 / grid_width
    cell_height = 336 / grid_height

    if highlighted_x is not None and highlighted_y is not None:
        circle_radius = min(336 / grid_width, 336 / grid_height) / 2

        circle = plt.Circle((highlighted_x, highlighted_y), circle_radius,
                            edgecolor='lime', facecolor='none', linewidth=3)
        plt.gca().add_patch(circle)

        plt.text(highlighted_x, highlighted_y - circle_radius - 5,
                 f"Position {highlighted_position}",
                 ha='center', va='center', color='lime',
                 fontweight='bold', fontsize=12,
                 bbox=dict(facecolor='black', alpha=0.7, pad=1))
    else:
        print(f"Warning: position {highlighted_position} was not found in the grid")

    plt.axis('off')
    plt.title(f"Highlighted Position {highlighted_position} on Image")
    os.makedirs(args.output, exist_ok=True)
    plt.savefig(f"{args.output}/position_{highlighted_position}_highlight.png")
    plt.close()

    # Figure 2: overlay all position ids of the image region on the image.
    plt.figure(figsize=(10, 10))
    plt.imshow(img)

    for y in range(grid_height):
        for x in range(grid_width):
            if y * grid_width + x < position_region_length:
                pos_id = position_grid[y, x]

                img_x = int(x * 336 / grid_width + 336 / (2 * grid_width))
                img_y = int(y * 336 / grid_height + 336 / (2 * grid_height))

                cell_width = 336 / grid_width
                cell_height = 336 / grid_height

                rect = plt.Rectangle((img_x - cell_width / 2, img_y - cell_height / 2),
                                     cell_width, cell_height,
                                     linewidth=1, edgecolor='white', facecolor='none', alpha=0.3)
                plt.gca().add_patch(rect)

                # Label every cell for small grids, otherwise every third cell.
                if grid_width <= 10 or (x % 3 == 0 and y % 3 == 0):
                    plt.text(img_x, img_y, str(pos_id),
                             ha='center', va='center', color='white',
                             bbox=dict(facecolor='black', alpha=0.5, pad=1))

    plt.axis('off')
    plt.title("Position IDs from begin_2 to end_2 overlaid on image")
    os.makedirs(args.output, exist_ok=True)
    plt.savefig(f"{args.output}/position_ids_overlay.png")
    plt.close()

    # Visualise where the chosen target token(s) attend within the image tokens, layer
    # by layer. The script assumes the image patch tokens occupy sequence slots 36:612
    # (24x24 = 576 patches from the 336px CLIP vision tower).
    target_position = [948]
    if args.target_text is not None:
        target_position = []
        target_tokens = tokenizer.tokenize(args.target_text)
        target_tokens_ids = tokenizer.convert_tokens_to_ids(target_tokens)
        for j in range(0, input_ids.shape[1]):
            if input_ids[0][j] in target_tokens_ids:
                # Offset text-token indices by the number of inserted image tokens.
                target_position.append(j + len(position_ids[0]) - len(input_ids[0]))
        print(f"target_position: {target_position}")
        # Fail early if none of the target tokens appear in the prompt.
        if len(target_position) == 0:
            raise ValueError(f"No tokens of --target_text {args.target_text!r} were found in the prompt.")

    for k in range(num_layer):
        attention = output.attentions[k].squeeze(0)  # (num_heads, seq_len, seq_len)

        # Average the attention from every target position onto the image tokens,
        # then average over heads.
        avg_attention = torch.zeros_like(attention[:, 0, 36:612])
        for pos in target_position:
            avg_attention += attention[:, pos, 36:612]
        avg_attention = avg_attention / len(target_position)
        attention_target = avg_attention.mean(dim=0)

        # Sharpen the distribution and reshape to the 24x24 patch grid.
        attention_target = torch.softmax(attention_target * 200, dim=0).view(24, 24)
        attention_target = np.array(attention_target.cpu(), dtype=np.float32) * 100

        img = mmbench_data[i]['image']
        img = img.resize((336, 336), Image.BILINEAR)
        img.save('example.jpg')
        resized_attention = np.array(Image.fromarray((attention_target * 255).astype(np.uint8)).resize(img.size, resample=Image.BILINEAR))
        smoothed_attention = gaussian_filter(resized_attention, sigma=2)

        plt.figure(figsize=(img.size[0] / 100, img.size[1] / 100))
        sns.heatmap(smoothed_attention, cmap="jet", alpha=0.5, zorder=2)
        plt.imshow(img, aspect='auto', zorder=1)
        plt.axis('off')
        os.makedirs(args.output, exist_ok=True)
        plt.savefig(f"{args.output}/attn_layer{k}_{'_'.join(args.target_text.split()) if args.target_text else target_position[0]}.png")
        plt.close()

    print('done')