File size: 3,458 Bytes
fc889e7
f2df7d1
fc889e7
f2df7d1
 
 
 
 
 
 
 
9017c29
 
 
 
f2df7d1
 
 
2eb7fd3
f2df7d1
 
 
fc889e7
 
 
 
bc0be6a
fc889e7
f2df7d1
 
 
 
 
fc889e7
f2df7d1
 
 
bc0be6a
 
f2df7d1
 
 
fc889e7
 
f2df7d1
fc889e7
f2df7d1
bc0be6a
5ad2922
 
 
bc0be6a
fc889e7
 
 
 
 
 
 
9017c29
5ad2922
f2df7d1
5ad2922
fc889e7
f2df7d1
5ad2922
fc889e7
 
 
5ad2922
 
 
fc889e7
 
 
 
 
5ad2922
 
fc889e7
bc0be6a
 
 
 
5ad2922
fc889e7
 
bc0be6a
9017c29
f2df7d1
bc0be6a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91

import io
import logging
from functools import lru_cache
from typing import Optional

import gradio as gr
import matplotlib.pyplot as plt
from PIL import Image

from transformers.utils.processor_visualizer_utils import ImageVisualizer

# Model ids offered in the "Models" radio picker. The first entry is the
# default selection, and every entry is warmed into the get_viz cache by
# _preload_models.
MODELS = [
    "openai/clip-vit-base-patch32",
    "HuggingFaceM4/Idefics3-8B-Llama3",
]

def _fig_to_pil(fig) -> Image.Image:
    """Rasterize *fig* to PNG in memory and return it as an RGB PIL image.

    Args:
        fig: Object exposing ``savefig`` (here, a matplotlib figure).

    Returns:
        The rendered figure, saved at 160 dpi and converted to RGB mode.
    """
    png_stream = io.BytesIO()
    fig.savefig(png_stream, format="png", dpi=160)
    # Rewind so PIL reads the PNG from its first byte.
    png_stream.seek(0)
    decoded = Image.open(png_stream)
    return decoded.convert("RGB")

@lru_cache(maxsize=64)
def get_viz(model_id: str) -> ImageVisualizer:
    """Return the ``ImageVisualizer`` for *model_id*.

    Results are memoized per model id (up to 64 entries), so repeated
    requests for the same model reuse the already-constructed visualizer.
    """
    visualizer = ImageVisualizer(model_id)
    return visualizer

def _run(model_id: str, image: Optional[Image.Image], add_grid: bool):
    """Run the visualizer for one model and return its figures as images.

    ``plt.show`` is temporarily replaced so that every figure the visualizer
    would display is captured instead. The captured figures are converted to
    PIL images and then closed, so repeated requests do not accumulate open
    matplotlib figures in a long-running server.

    Args:
        model_id: Model id passed to ``get_viz``.
        image: Image forwarded to ``viz.visualize`` as ``images``; may be
            ``None``.
        add_grid: Forwarded to ``viz.visualize`` as ``add_grid``.

    Returns:
        A ``(left_img, right_img, prompt_preview)`` tuple. ``left_img`` and
        ``right_img`` are the first and second captured figures as RGB PIL
        images, or ``None`` when fewer figures were shown. ``prompt_preview``
        is the value returned by ``viz.default_message(full_output=False)``.
    """
    viz = get_viz(model_id)

    captured = []
    orig_show = plt.show

    def _capture_show(*_, **__):
        # Record the current figure instead of displaying it.
        captured.append(plt.gcf())

    try:
        # NOTE: this patch is process-wide while it is active; it is restored
        # as soon as the visualizer returns or raises.
        try:
            plt.show = _capture_show
            # if image is None, the visualizer will use its default sample
            viz.visualize(images=image, add_grid=add_grid)
        finally:
            plt.show = orig_show

        left_img = _fig_to_pil(captured[0]) if len(captured) >= 1 else None
        right_img = _fig_to_pil(captured[1]) if len(captured) >= 2 else None
    finally:
        # pyplot keeps a reference to every figure until it is closed
        # explicitly; release them whether or not conversion succeeded.
        for fig in captured:
            plt.close(fig)

    prompt_preview = viz.default_message(full_output=False)
    return left_img, right_img, prompt_preview

def _resolve_and_run(model_pick, custom_model, image, add_grid):
    """Pick the model id to use and delegate to ``_run``.

    A non-blank ``custom_model`` takes precedence over ``model_pick``; both
    are stripped of surrounding whitespace, and ``None`` counts as empty.

    Raises:
        gr.Error: If neither value yields a non-empty model id.
    """
    model_id = (custom_model or "").strip()
    if not model_id:
        # Fall back to the radio selection only when nothing was typed.
        model_id = (model_pick or "").strip()
    if model_id:
        return _run(model_id, image, add_grid)
    raise gr.Error("Pick a model or enter one.")

def _preload_models():
    """Warm the ``get_viz`` cache for every model id in ``MODELS``.

    Preloading is best-effort: a model that fails to load is logged with its
    traceback and skipped, so one failure does not prevent the remaining
    models from loading or the app from starting.
    """
    for mid in MODELS:
        try:
            get_viz(mid)
        except Exception:
            # Deliberately non-fatal, but leave a trace instead of hiding
            # the failure.
            logging.getLogger(__name__).warning(
                "Could not preload visualizer for %r", mid, exc_info=True
            )

# Soft theme with an orange primary hue and gray neutrals for the demo.
theme = gr.themes.Soft(primary_hue="orange", neutral_hue="gray")

with gr.Blocks(title="Transformers Processor Visualizer", theme=theme) as demo:
    gr.Markdown("## Visualize what a processor feeds a vision–text model")

    with gr.Row():
        # Narrow column: model selection.
        with gr.Column(scale=1, min_width=280):
            model_pick = gr.Radio(
                label="Models",
                choices=MODELS,
                value=MODELS[0],
                interactive=True,
            )
            custom_model = gr.Textbox(
                label="Or type a model id",
                placeholder="owner/repo",
                lines=1,
            )
        # Wide column: options, image upload and rendered outputs.
        with gr.Column(scale=3):
            with gr.Row():
                add_grid = gr.Checkbox(label="Show patch grid", value=True)
            image = gr.Image(
                label="Upload custom image",
                type="pil",
                height=140,
                sources=["upload"],
            )
            gr.Markdown("## Output")
            with gr.Row():
                left_output = gr.Image(
                    label="Processor output",
                    type="pil",
                    height=900,
                )
                right_output = gr.Image(
                    label="Global image (if any)",
                    type="pil",
                    height=900,
                )
            prompt = gr.Textbox(label="Compact chat template preview", lines=2)

    # Every trigger sends the same inputs to the same handler and fills the
    # same outputs, so the wiring is declared once and reused.
    _inputs = [model_pick, custom_model, image, add_grid]
    _outputs = [left_output, right_output, prompt]
    for _register in (
        model_pick.change,
        custom_model.submit,
        add_grid.change,
        image.change,
    ):
        _register(_resolve_and_run, _inputs, _outputs)

    # On page load: warm the model cache, and render once with the defaults.
    demo.load(_preload_models, [], [])
    demo.load(_resolve_and_run, _inputs, _outputs)

# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()