Error: `temp_len: 5217, output_imgs[-1].shape[1]: 5233`

#74
by HERIUN - opened

I tested this model on the MMMU validation split (mmmu_val).

On the full validation set (all 900 examples, no shuffling), there is an error near index 740.

The failing example is `example["id"] == "validation_Music_21"`.

from datasets import load_dataset, concatenate_datasets
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig

# Hugging Face Hub id of the model under evaluation.
model_path = "microsoft/Phi-4-multimodal-instruct"

# Load the processor and the model. trust_remote_code lets transformers run the
# custom code shipped in the model repository.
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    # device_map already places the weights on the GPU, so the former trailing
    # `.cuda()` call was redundant.
    device_map="cuda",
    torch_dtype="auto",
    trust_remote_code=True,
    # if you do not use Ampere or later GPUs, change attention to "eager"
    _attn_implementation='flash_attention_2',
)

# MMMU subject configurations to evaluate (each is passed as `name=` to
# load_dataset). Kept in alphabetical order; the concatenation order below
# determines the global example index.
data_subsets = [
    "Accounting", "Agriculture", "Architecture_and_Engineering",
    "Art", "Art_Theory", "Basic_Medical_Science",
    "Biology", "Chemistry", "Clinical_Medicine",
    "Computer_Science", "Design", "Diagnostics_and_Laboratory_Medicine",
    "Economics", "Electronics", "Energy_and_Power",
    "Finance", "Geography", "History",
    "Literature", "Manage", "Marketing",
    "Materials", "Math", "Mechanical_Engineering",
    "Music", "Pharmacy", "Physics",
    "Psychology", "Public_Health", "Sociology",
]
# Split evaluated for every subject below.
data_split = "validation"

# Load each subject's split and concatenate them in data_subsets order (no
# shuffling), so a global index always maps to the same example.
loaded_dataset = [
    load_dataset("mmmu/mmmu", name=subset_name, split=data_split)
    for subset_name in data_subsets
]
dataset = concatenate_datasets(loaded_dataset)

# Debug aid: jump to the region where the reported failure occurs (near global
# index 740, example id "validation_Music_21"). Remove to evaluate everything.
dataset = dataset.skip(740)


def phi4mminst_mmmu_preprocess(example):
    """Build a single-turn chat prompt and its image list from one MMMU example.

    The question and its lettered options (A, B, C, ...) are rendered into one
    user message. Image references written as ``<image N>`` are rewritten to
    ``<|image_N|>`` placeholders. For every placeholder in the final prompt, in
    order of appearance and including duplicates, ``example["image_N"]`` is
    collected.

    Args:
        example: Mapping with ``"question"`` (str), ``"options"`` (a string
            holding a Python list literal, e.g. ``"['a', 'b']"``) and an
            ``"image_N"`` entry for every image the prompt references.

    Returns:
        dict with ``"messages"`` (a one-element list holding the user turn)
        and ``"images"`` (the referenced image values, in prompt order).

    Raises:
        ValueError / SyntaxError: if ``example["options"]`` is not a plain
            Python literal.
        KeyError: if the prompt references an ``image_N`` field the example
            does not have.
    """
    # Local imports: the surrounding script uses re/string without importing them.
    import ast
    import re
    import string

    question = example["question"]

    # "options" is a serialized Python list. Parse it once with literal_eval:
    # unlike eval(), it cannot execute code embedded in dataset text.
    option_list = ast.literal_eval(example["options"])
    # Label options A, B, C, ...; zip stops at 26, matching the old slice.
    options = dict(zip(string.ascii_uppercase, option_list))

    prompt = f'Question: {question}\n'
    if options:
        prompt += 'Options:\n'
        prompt += ''.join(f'{key}. {item}\n' for key, item in options.items())
        prompt += 'Please select the correct answer from the options above. \n'
    prompt = prompt.rstrip()

    # Rewrite "<image 1>" -> "<|image_1|>" everywhere, including inside options.
    prompt = re.sub(r"<image (\d+)>", r"<|image_\1|>", prompt)

    # Collect example["image_N"] for every placeholder, in order of appearance.
    images = [
        example[f"image_{number}"]
        for number in re.findall(r"<\|image_(\d+)\|>", prompt)
    ]

    return {
        "messages": [{'role': 'user', 'content': prompt}],
        "images": images,
    }


# Run generation over every example. Any failure is re-raised with the
# offending example id, so problematic samples can be located directly.
# NOTE(review): `tqdm` and `generate_kwargs` are not imported/defined in the
# visible script; they must exist before this loop runs.
for example in tqdm(dataset, total=len(dataset), desc="inference & evaluate"):
    try:
        sample = phi4mminst_mmmu_preprocess(example)
        # Render the chat messages to prompt text, then tokenize it together
        # with the images referenced by the prompt.
        text = processor.apply_chat_template(
            sample["messages"],
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = processor(
            text=text,
            images=sample["images"],
            padding=True,
            return_tensors='pt',
        ).to(model.device, model.dtype)
        outputs = model.generate(
            **inputs,
            **generate_kwargs
        )
    except Exception as err:
        raise RuntimeError(f"inference failed on example {example['id']}") from err

Sign up or log in to comment